diff --git a/.pylintrc b/.pylintrc
index a1ccb0e..616c20a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -82,7 +82,7 @@ persistent=yes
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
-py-version=3.7
+py-version=3.8
# Discover python modules and packages in the file system subtree.
recursive=no
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 895d99f..e2a2939 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -9,10 +9,13 @@ If you would like to implement a new feature or a bug, please make sure you (or
### Creating a Pull Request
1. [Fork](https://github.com/cnellington/Contextualized/fork) this repository.
-2. Make your code changes locally.
-3. Check the style using pylint and black following [these steps](https://github.com/cnellington/Contextualized/pull/111#issue-1323230194).
-4. (Optional) Include your name in alphabetical order in [ACKNOWLEDGEMENTS.md](https://github.com/cnellington/Contextualized/blob/main/ACKNOWLEDGEMENTS.md).
-5. Issue a PR to merge your changes into the `dev` branch.
+2. Install locally with `pip install -e .`.
+3. Install extra developer dependencies with `pip install -r dev_requirements.txt`.
+4. Make your code changes locally.
+5. Automatically format your code and check for style issues by running `format_style.sh`. We are working on linting the entire repo, but please make sure your code is cleared by pylint.
+6. Automatically update our documentation by running `update_docs.sh`.
+7. (Optional) Include your name in alphabetical order in [ACKNOWLEDGEMENTS.md](https://github.com/cnellington/Contextualized/blob/main/ACKNOWLEDGEMENTS.md).
+8. Issue a PR to merge your changes into the `main` branch.
## Issues
diff --git a/README.md b/README.md
index c38021c..98e7787 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-![Preview](contextualized_logo.png)
+![Preview](docs/logo.png)
#
![License](https://img.shields.io/github/license/cnellington/contextualized.svg?style=flat-square)
@@ -10,7 +10,7 @@
-A statistical machine learning toolbox for estimating models, distributions, and functions with context-specific parameters.
+An easy-to-use machine learning toolbox for estimating models, distributions, and functions with context-specific parameters.
Context-specific parameters:
- Find hidden heterogeneity in data -- are all samples the same?
@@ -66,13 +66,16 @@ Feel free to add your own page(s) by sending a PR or request an improvement by c
-ContextualizedML was originally implemented by [Caleb Ellington](https://calebellington.com/) (CMU) and [Ben Lengerich](http://web.mit.edu/~blengeri/www) (MIT).
+Contextualized ML was originally implemented by [Caleb Ellington](https://calebellington.com/) (CMU) and [Ben Lengerich](http://web.mit.edu/~blengeri/www) (MIT).
Many people have helped. Check out [ACKNOWLEDGEMENTS.md](https://github.com/cnellington/Contextualized/blob/main/ACKNOWLEDGEMENTS.md)!
## Related Publications and Pre-prints
+- [Contextualized Machine Learning](https://arxiv.org/abs/2310.11340)
+- [Contextualized Networks Reveal Heterogeneous Transcriptomic Regulation in Tumors at Sample-Specific Resolution](https://www.biorxiv.org/content/10.1101/2023.12.01.569658v1)
+- [Contextualized Policy Recovery: Modeling and Interpreting Medical Decisions with Adaptive Imitation Learning](https://arxiv.org/abs/2310.07918)
- [Automated Interpretable Discovery of Heterogeneous Treatment Effectiveness: A COVID-19 Case Study](https://www.sciencedirect.com/science/article/pii/S1532046422001022)
- [NOTMAD: Estimating Bayesian Networks with Sample-Specific Structures and Parameters](http://arxiv.org/abs/2111.01104)
- [Discriminative Subtyping of Lung Cancers from Histopathology Images via Contextual Deep Learning](https://www.medrxiv.org/content/10.1101/2020.06.25.20140053v1.abstract)
diff --git a/contextualized/analysis/__init__.py b/contextualized/analysis/__init__.py
index b08ac8e..0812ce8 100644
--- a/contextualized/analysis/__init__.py
+++ b/contextualized/analysis/__init__.py
@@ -12,3 +12,8 @@
plot_homogeneous_predictor_effects,
plot_heterogeneous_predictor_effects,
)
+from contextualized.analysis.pvals import (
+ calc_homogeneous_context_effects_pvals,
+ calc_homogeneous_predictor_effects_pvals,
+ calc_heterogeneous_predictor_effects_pvals,
+)
diff --git a/contextualized/analysis/accuracy_split.py b/contextualized/analysis/accuracy_split.py
index 06e7cdd..8bdb5fc 100644
--- a/contextualized/analysis/accuracy_split.py
+++ b/contextualized/analysis/accuracy_split.py
@@ -1,8 +1,10 @@
"""
Utilities for post-hoc analysis of trained Contextualized models.
"""
+from typing import *
import numpy as np
+import pandas as pd
from sklearn.metrics import roc_auc_score as roc
@@ -14,11 +16,25 @@ def get_roc(Y_true, Y_pred):
return np.nan
-def print_acc_by_covars(Y_true, Y_pred, covar_df, **kwargs):
+def print_acc_by_covars(
+ Y_true: np.ndarray, Y_pred: np.ndarray, covar_df: pd.DataFrame, **kwargs
+) -> None:
"""
Prints Accuracy for different data splits with covariates.
- Assume Y_true and Y_pred are np arrays.
- Allows train_idx and test_idx as Boolean locators.
+
+ Args:
+ Y_true (np.ndarray): True labels.
+ Y_pred (np.ndarray): Predicted labels.
+ covar_df (pd.DataFrame): DataFrame of covariates.
+ max_classes (int, optional): Maximum number of classes to print. Defaults to 20.
+ covar_stds (np.ndarray, optional): Standard deviations of covariates. Defaults to None.
+ covar_means (np.ndarray, optional): Means of covariates. Defaults to None.
+ covar_encoders (List[LabelEncoder], optional): Encoders for covariates. Defaults to None.
+ train_idx (np.ndarray, optional): Boolean array indicating training data. Defaults to None.
+ test_idx (np.ndarray, optional): Boolean array indicating testing data. Defaults to None.
+
+ Returns:
+ None
"""
Y_true = np.squeeze(Y_true)
Y_pred = np.squeeze(Y_pred)
diff --git a/contextualized/analysis/bootstraps.py b/contextualized/analysis/bootstraps.py
index 6204fc9..a9e2eba 100644
--- a/contextualized/analysis/bootstraps.py
+++ b/contextualized/analysis/bootstraps.py
@@ -1,5 +1,6 @@
# Utility functions for bootstraps
+
def select_good_bootstraps(sklearn_wrapper, train_errs, tol=2, **kwargs):
"""
Select bootstraps that are good for a given model.
@@ -19,5 +20,6 @@ def select_good_bootstraps(sklearn_wrapper, train_errs, tol=2, **kwargs):
train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2))
sklearn_wrapper.models = sklearn_wrapper.models[
- train_errs_by_bootstrap < tol*np.min(train_errs_by_bootstrap)]
+ train_errs_by_bootstrap < tol * np.min(train_errs_by_bootstrap)
+ ]
return sklearn_wrapper
diff --git a/contextualized/analysis/effects.py b/contextualized/analysis/effects.py
index 6c15769..0787b6e 100644
--- a/contextualized/analysis/effects.py
+++ b/contextualized/analysis/effects.py
@@ -1,21 +1,29 @@
"""
Utilities for plotting effects learned by Contextualized models.
"""
-
+from typing import *
import numpy as np
import matplotlib.pyplot as plt
+from contextualized.easy.wrappers import SKLearnWrapper
+
def simple_plot(
- x_vals,
- y_vals,
+ x_vals: List[Union[float, int]],
+ y_vals: List[Union[float, int]],
**kwargs,
-):
+) -> None:
"""
- Simple plotting of xs and ys with kwargs passed to mpl helpers.
- :param x_vals:
- :param y_vals:
+ Simple plotting of y vs x with kwargs passed to matplotlib helpers.
+
+ Args:
+ x_vals: x values to plot
+ y_vals: y values to plot
+ **kwargs: kwargs passed to matplotlib helpers (fill_alpha, fill_color, y_lowers, y_uppers, x_label, y_label, x_ticks, x_ticklabels, y_ticks, y_ticklabels)
+
+ Returns:
+ None
"""
plt.figure(figsize=kwargs.get("figsize", (8, 8)))
if "y_lowers" in kwargs and "y_uppers" in kwargs:
@@ -84,16 +92,25 @@ def plot_effect(x_vals, y_means, y_lowers=None, y_uppers=None, **kwargs):
)
-def get_homogeneous_context_effects(model, C, **kwargs):
+def get_homogeneous_context_effects(
+ model: SKLearnWrapper, C: np.ndarray, **kwargs
+) -> Tuple[np.ndarray, np.ndarray]:
"""
Get the homogeneous (context-invariant) effects of context.
- :param model:
- :param C:
- returns:
- c_vis: the context values that were used to estimate the effects
- effects: np array of effects, one for each context. Each homogeneous effect is a matrix of shape:
- (n_bootstraps, n_context_vals, n_outcomes).
+ Args:
+ model (SKLearnWrapper): a fitted ``contextualized.easy`` model
+ C: the context values to use to estimate the effects
+ verbose (bool, optional): print progress. Default True.
+ individual_preds (bool, optional): whether to plot each bootstrap. Default True.
+ C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
+ n_vis (int, optional): Number of bins to use to visualize context. Default 1000.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ c_vis: the context values that were used to estimate the effects
+ effects: array of effects, one for each context. Each homogeneous effect is a matrix of shape:
+ (n_bootstraps, n_context_vals, n_outcomes).
"""
if kwargs.get("verbose", True):
print("Estimating Homogeneous Contextual Effects.")
@@ -233,14 +250,32 @@ def plot_boolean_vars(names, y_mean, y_err, **kwargs):
def plot_homogeneous_context_effects(
- model,
- C,
+ model: SKLearnWrapper,
+ C: np.ndarray,
**kwargs,
-):
+) -> None:
"""
- Plot the homogeneous (context-invariant) effects of context.
- :param model:
- :param C:
+ Plot the direct effect of context on outcomes, disregarding other features.
+ This context effect is homogeneous in that it is a static function of context (context-invariant).
+
+ Args:
+ model (SKLearnWrapper): a fitted ``contextualized.easy`` model
+ C: the context values to use to estimate the effects
+ verbose (bool, optional): print progress. Default True.
+ individual_preds (bool, optional): whether to plot each bootstrap. Default True.
+ C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
+ n_vis (int, optional): Number of bins to use to visualize context. Default 1000.
+ lower_pct (float, optional): Lower percentile for bootstraps. Default 2.5.
+ upper_pct (float, optional): Upper percentile for bootstraps. Default 97.5.
+ classification (bool, optional): Whether to exponentiate the effects. Default True.
+ C_encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each context. Default None.
+ C_means (np.ndarray, optional): means for each context. Default None.
+ C_stds (np.ndarray, optional): standard deviations for each context. Default None.
+ xlabel_prefix (str, optional): prefix for x label. Default "".
+ figname (str, optional): name of figure to save. Default None.
+
+ Returns:
+ None
"""
c_vis, effects = get_homogeneous_context_effects(model, C, **kwargs)
# effects.shape is (n_context, n_bootstraps, n_context_vals, n_outcomes)
@@ -283,16 +318,34 @@ def plot_homogeneous_context_effects(
def plot_homogeneous_predictor_effects(
- model,
- C,
- X,
+ model: SKLearnWrapper,
+ C: np.ndarray,
+ X: np.ndarray,
**kwargs,
-):
+) -> None:
"""
- Plot the homogeneous (context-invariant) effects of predictors.
- :param model:
- :param C:
- :param X:
+ Plot the effect of predictors on outcomes that do not change with context (homogeneous).
+
+ Args:
+ model (SKLearnWrapper): a fitted ``contextualized.easy`` model
+ C: the context values to use to estimate the effects
+ X: the predictor values to use to estimate the effects
+ max_classes_for_discrete (int, optional): maximum number of classes to treat as discrete. Default 10.
+ min_effect_size (float, optional): minimum effect size to plot. Default 0.003.
+ ylabel (str, optional): y label for plot. Default "Influence of ".
+ xlabel_prefix (str, optional): prefix for x label. Default "".
+ X_names (List[str], optional): names of predictors. Default None.
+ X_encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each predictor. Default None.
+ X_means (np.ndarray, optional): means for each predictor. Default None.
+ X_stds (np.ndarray, optional): standard deviations for each predictor. Default None.
+ verbose (bool, optional): print progress. Default True.
+ lower_pct (float, optional): Lower percentile for bootstraps. Default 2.5.
+ upper_pct (float, optional): Upper percentile for bootstraps. Default 97.5.
+ classification (bool, optional): Whether to exponentiate the effects. Default True.
+ figname (str, optional): name of figure to save. Default None.
+
+ Returns:
+ None
"""
c_vis = np.zeros_like(C.values)
x_vis = make_grid_mat(X.values, 1000)
@@ -355,19 +408,31 @@ def plot_homogeneous_predictor_effects(
def plot_heterogeneous_predictor_effects(model, C, X, **kwargs):
"""
- Plot the heterogeneous (context-dependent) effects of context.
- :param model:
- :param C:
- :param X:
- :param encoders:
- :param C_means:
- :param C_stds:
- :param X_names:
- :param ylabel: (Default value = "Influence of ")
- :param min_effect_size: (Default value = 0.003)
- :param n_vis: (Default value = 1000)
- :param max_classes_for_discrete: (Default value = 10)
-
+ Plot how the effect of predictors on outcomes changes with context (heterogeneous).
+
+ Args:
+ model (SKLearnWrapper): a fitted ``contextualized.easy`` model
+ C: the context values to use to estimate the effects
+ X: the predictor values to use to estimate the effects
+ max_classes_for_discrete (int, optional): maximum number of classes to treat as discrete. Default 10.
+ min_effect_size (float, optional): minimum effect size to plot. Default 0.003.
+ y_prefix (str, optional): y prefix for plot. Default "Influence of ".
+ X_names (List[str], optional): names of predictors. Default None.
+ verbose (bool, optional): print progress. Default True.
+ individual_preds (bool, optional): whether to plot each bootstrap. Default True.
+ C_vis (np.ndarray, optional): Context bins used to visualize context (n_vis, n_contexts). Default None to construct anew.
+ n_vis (int, optional): Number of bins to use to visualize context. Default 1000.
+ lower_pct (float, optional): Lower percentile for bootstraps. Default 2.5.
+ upper_pct (float, optional): Upper percentile for bootstraps. Default 97.5.
+ classification (bool, optional): Whether to exponentiate the effects. Default True.
+ C_encoders (List[sklearn.preprocessing.LabelEncoder], optional): encoders for each context. Default None.
+ C_means (np.ndarray, optional): means for each context. Default None.
+ C_stds (np.ndarray, optional): standard deviations for each context. Default None.
+ xlabel_prefix (str, optional): prefix for x label. Default "".
+ figname (str, optional): name of figure to save. Default None.
+
+ Returns:
+ None
"""
c_vis = maybe_make_c_vis(C, **kwargs)
n_vis = c_vis.shape[0]
diff --git a/contextualized/analysis/embeddings.py b/contextualized/analysis/embeddings.py
index d21257c..9cb649b 100644
--- a/contextualized/analysis/embeddings.py
+++ b/contextualized/analysis/embeddings.py
@@ -1,9 +1,10 @@
"""
Utilities for plotting embeddings of fitted Contextualized models.
"""
-
+from typing import *
import numpy as np
+import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
@@ -11,16 +12,26 @@
def plot_embedding_for_all_covars(
- reps, covars_df, covars_stds=None, covars_means=None, covars_encoders=None, **kwargs
-):
+ reps: np.ndarray,
+ covars_df: pd.DataFrame,
+ covars_stds: np.ndarray = None,
+ covars_means: np.ndarray = None,
+ covars_encoders: List[Callable] = None,
+ **kwargs,
+) -> None:
"""
Plot embeddings of representations for all covariates in a Pandas dataframe.
- :param reps:
- :param covars_df:
- :param covars_stds: Used to project back to readable values. (Default value = None)
- :param covars_means: Used to project back to readable values. (Default value = None)
- :param covars_encoders: Used to project back to readable values. (Default value = None)
- :param kwargs: Keyword arguments for plotting.
+
+ Args:
+ reps (np.ndarray): Embeddings of shape (n_samples, n_dims).
+ covars_df (pd.DataFrame): DataFrame of covariates.
+ covars_stds (np.ndarray, optional): Standard deviations of covariates. Defaults to None.
+ covars_means (np.ndarray, optional): Means of covariates. Defaults to None.
+ covars_encoders (List[LabelEncoder], optional): Encoders for covariates. Defaults to None.
+ kwargs: Keyword arguments for plotting.
+
+ Returns:
+ None
"""
for i, covar in enumerate(covars_df.columns):
my_labels = covars_df.iloc[:, i].values
@@ -49,17 +60,20 @@ def plot_embedding_for_all_covars(
def plot_lowdim_rep(
- low_dim,
- labels,
+ low_dim: np.ndarray,
+ labels: np.ndarray,
**kwargs,
):
"""
+ Plot a low-dimensional representation of a dataset.
- :param low_dim:
- :param labels:
- :param kwargs:
- Keyword arguments.
+ Args:
+ low_dim (np.ndarray): Low-dimensional representation of shape (n_samples, 2).
+ labels (np.ndarray): Labels of shape (n_samples,).
+ kwargs: Keyword arguments for plotting.
+ Returns:
+ None
"""
if len(set(labels)) < kwargs.get("max_classes_for_discrete", 10): # discrete labels
diff --git a/contextualized/analysis/pvals.py b/contextualized/analysis/pvals.py
index e98efd0..c745199 100644
--- a/contextualized/analysis/pvals.py
+++ b/contextualized/analysis/pvals.py
@@ -2,6 +2,7 @@
Analysis tools for generating pvalues from bootstrap replicates.
"""
+from typing import *
import numpy as np
@@ -10,6 +11,7 @@
get_homogeneous_predictor_effects,
get_heterogeneous_predictor_effects,
)
+from contextualized.easy.wrappers import SKLearnWrapper
def calc_pval_bootstraps_one_sided(estimates, thresh=0, laplace_smoothing=1):
@@ -48,19 +50,19 @@ def calc_pval_bootstraps_one_sided_mean(estimates, laplace_smoothing=1):
)
-def calc_homogeneous_context_effects_pvals(model, C, **kwargs):
+def calc_homogeneous_context_effects_pvals(
+ model: SKLearnWrapper, C: np.ndarray, **kwargs
+) -> np.ndarray:
"""
Calculate p-values for the effects of context.
- Parameters
- ----------
- model : contextualized.models.Model
- C : np.ndarray
+ Args:
+ model (SKLearnWrapper): Model to analyze.
+ C (np.ndarray): Contexts to analyze.
- Returns
- -------
- pvals : np.ndarray of shape (n_contexts, n_outcomes) testing whether the
- sign is consistent across bootstraps
+ Returns:
+ np.ndarray: P-values of shape (n_contexts, n_outcomes) testing whether the
+ sign of the direct effect of context on outcomes is consistent across bootstraps.
"""
_, effects = get_homogeneous_context_effects(model, C, **kwargs)
# effects.shape: (n_contexts, n_bootstraps, n_context_vals, n_outcomes)
@@ -86,19 +88,19 @@ def calc_homogeneous_context_effects_pvals(model, C, **kwargs):
return pvals
-def calc_homogeneous_predictor_effects_pvals(model, C, **kwargs):
+def calc_homogeneous_predictor_effects_pvals(
+ model: SKLearnWrapper, C: np.ndarray, **kwargs
+) -> np.ndarray:
"""
- Calculate p-values for the effects of predictors.
+ Calculate p-values for the context-invariant effects of predictors.
- Parameters
- ----------
- model : contextualized.models.Model
- C : np.ndarray
+ Args:
+ model (SKLearnWrapper): Model to analyze.
+ C (np.ndarray): Contexts to analyze.
- Returns
- -------
- pvals : np.ndarray of shape (n_predictors, n_outcomes) testing whether the
- sign is consistent across bootstraps
+ Returns:
+ np.ndarray: P-values of shape (n_predictors, n_outcomes) testing whether the
+ sign of the context-invariant predictor effects is consistent across bootstraps.
"""
_, effects = get_homogeneous_predictor_effects(model, C, **kwargs)
# effects.shape: (n_predictors, n_bootstraps, n_outcomes)
@@ -126,15 +128,13 @@ def calc_heterogeneous_predictor_effects_pvals(model, C, **kwargs):
"""
Calculate p-values for the heterogeneous effects of predictors.
- Parameters
- ----------
- model : contextualized.models.Model
- C : np.ndarray
+ Args:
+ model (SKLearnWrapper): Model to analyze.
+ C (np.ndarray): Contexts to analyze.
- Returns
- -------
- pvals : np.ndarray of shape (n_contexts, n_predictors, n_outcomes) testing
- whether the sign of the change wrt context is consistent across bootstraps
+ Returns:
+ np.ndarray: P-values of shape (n_contexts, n_predictors, n_outcomes) testing whether the
+ context-varying parameter range is consistent across bootstraps.
"""
_, effects = get_heterogeneous_predictor_effects(model, C, **kwargs)
# effects.shape is (n_contexts, n_predictors, n_bootstraps, n_context_vals, n_outcomes)
diff --git a/contextualized/analysis/utils.py b/contextualized/analysis/utils.py
index e420e4b..439f590 100644
--- a/contextualized/analysis/utils.py
+++ b/contextualized/analysis/utils.py
@@ -2,16 +2,20 @@
Miscellaneous utility functions.
"""
+from typing import *
+
import numpy as np
-def convert_to_one_hot(col):
+def convert_to_one_hot(col: Collection[Any]) -> Tuple[np.ndarray, List[Any]]:
"""
+ Converts a categorical variable to a one-hot vector.
- :param col: np array with observations
-
- returns col converted to one-hot values, and list of one-hot values.
+ Args:
+ col (Collection[Any]): The categorical variable.
+ Returns:
+ Tuple[np.ndarray, List[Any]]: The one-hot vector and the possible values.
"""
vals = list(set(col))
one_hot_vars = np.array([vals.index(x) for x in col], dtype=np.float32)
diff --git a/contextualized/dags/lightning_modules.py b/contextualized/dags/lightning_modules.py
index 9bda98c..a6099dc 100644
--- a/contextualized/dags/lightning_modules.py
+++ b/contextualized/dags/lightning_modules.py
@@ -28,7 +28,12 @@
}
DEFAULT_DAG_LOSS_TYPE = "NOTEARS"
DEFAULT_DAG_LOSS_PARAMS = {
- "NOTEARS": {"alpha": 1e-1, "rho": 1e-2, "tol": 0.25, "use_dynamic_alpha_rho": False},
+ "NOTEARS": {
+ "alpha": 1e-1,
+ "rho": 1e-2,
+ "tol": 0.25,
+ "use_dynamic_alpha_rho": False,
+ },
"DAGMA": {"s": 1, "alpha": 1e0},
"poly": {},
}
@@ -143,13 +148,14 @@ def __init__(
# DAG regularizers
self.ss_dag_params = sample_specific_loss_params["dag"].get(
"params",
- DEFAULT_DAG_LOSS_PARAMS[sample_specific_loss_params["dag"]["loss_type"]].copy(),
+ DEFAULT_DAG_LOSS_PARAMS[
+ sample_specific_loss_params["dag"]["loss_type"]
+ ].copy(),
)
-
self.arch_dag_params = archetype_loss_params["dag"].get(
- "params",
- DEFAULT_DAG_LOSS_PARAMS[archetype_loss_params["dag"]["loss_type"]].copy()
+ "params",
+ DEFAULT_DAG_LOSS_PARAMS[archetype_loss_params["dag"]["loss_type"]].copy(),
)
self.val_dag_loss_params = {"alpha": 1e0, "rho": 1e0}
@@ -415,7 +421,8 @@ def _maybe_update_alpha_rho(self, epoch_dag_loss, dag_params):
"""
if (
dag_params.get("use_dynamic_alpha_rho", False)
- and epoch_dag_loss > dag_params.get("tol", .25) * dag_params.get("h_old", 0)
+ and epoch_dag_loss
+ > dag_params.get("tol", 0.25) * dag_params.get("h_old", 0)
and dag_params["alpha"] < 1e12
and dag_params["rho"] < 1e12
):
diff --git a/contextualized/easy/ContextualGAM.py b/contextualized/easy/ContextualGAM.py
index 9b50bca..5ea6cda 100644
--- a/contextualized/easy/ContextualGAM.py
+++ b/contextualized/easy/ContextualGAM.py
@@ -9,7 +9,18 @@
class ContextualGAMClassifier(ContextualizedClassifier):
"""
- A GAM as context encoder with a classifier on top.
+ The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as heterogeneous disease diagnoses.
+ Implemented as a Contextual Generalized Additive Model with a classifier on top.
+ Always uses a Neural Additive Model ("ngam") encoder for interpretability.
+ See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
+ for more details.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
@@ -19,7 +30,18 @@ def __init__(self, **kwargs):
class ContextualGAMRegressor(ContextualizedRegressor):
"""
- A GAM as context encoder with a regressor on top.
+ The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous treatment effects.
+ Implemented as a Contextual Generalized Additive Model with a linear regressor on top.
+ Always uses a Neural Additive Model ("ngam") encoder for interpretability.
+ See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
+ for more details.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py
index a99d060..f5cab64 100644
--- a/contextualized/easy/ContextualizedClassifier.py
+++ b/contextualized/easy/ContextualizedClassifier.py
@@ -11,7 +11,16 @@
class ContextualizedClassifier(ContextualizedRegressor):
"""
- sklearn-like interface to Contextualized Classifiers.
+ Contextualized Logistic Regression reveals context-dependent decisions and decision boundaries.
+ Implemented as a ContextualizedRegressor with logistic link function and binary cross-entropy loss.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
@@ -20,14 +29,15 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
def predict(self, C, X, individual_preds=False, **kwargs):
- """
- Predict outcomes from context C and predictors X.
+ """Predict binary outcomes from context C and predictors X.
- :param C:
- :param X:
- :param individual_preds:
- :param **kwargs:
+ Args:
+ C (np.ndarray): Context array of shape (n_samples, n_context_features)
+ X (np.ndarray): Predictor array of shape (n_samples, n_features)
+ individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The binary outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True.
"""
return np.round(super().predict(C, X, individual_preds, **kwargs))
@@ -35,10 +45,13 @@ def predict_proba(self, C, X, **kwargs):
"""
Predict probabilities of outcomes from context C and predictors X.
- :param C:
- :param X:
- :param **kwargs:
+ Args:
+ C (np.ndarray): Context array of shape (n_samples, n_context_features)
+ X (np.ndarray): Predictor array of shape (n_samples, n_features)
+ individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The outcome probabilities predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True.
"""
# Returns a np array of shape N samples, K outcomes, 2.
probs = super().predict(C, X, **kwargs)
diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py
index 0cae6bc..ae3b5a7 100644
--- a/contextualized/easy/ContextualizedNetworks.py
+++ b/contextualized/easy/ContextualizedNetworks.py
@@ -1,16 +1,21 @@
"""
sklearn-like interface to Contextualized Networks.
"""
+from typing import *
+
import numpy as np
from contextualized.easy.wrappers import SKLearnWrapper
from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer
from contextualized.regression.lightning_modules import (
ContextualizedCorrelation,
- # TasksplitContextualizedCorrelation, # TODO: Incorporate Tasksplit
ContextualizedMarkovGraph,
)
-from contextualized.dags.lightning_modules import NOTMAD, DEFAULT_DAG_LOSS_TYPE, DEFAULT_DAG_LOSS_PARAMS
+from contextualized.dags.lightning_modules import (
+ NOTMAD,
+ DEFAULT_DAG_LOSS_TYPE,
+ DEFAULT_DAG_LOSS_PARAMS,
+)
from contextualized.dags.trainers import GraphTrainer
from contextualized.dags.graph_utils import dag_pred_np
@@ -20,28 +25,77 @@ class ContextualizedNetworks(SKLearnWrapper):
sklearn-like interface to Contextualized Networks.
"""
- def _split_train_data(self, C, X, **kwargs):
- return super()._split_train_data(C, X, Y_required=False, **kwargs)
+ def _split_train_data(
+ self, C: np.ndarray, X: np.ndarray, **kwargs
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """Splits data into train and test sets.
- def predict_networks(self, C, with_offsets=False, **kwargs):
+ Args:
+ C (np.ndarray): Contextual features for each sample.
+ X (np.ndarray): The data matrix.
+
+ Returns:
+ Tuple[List[np.ndarray], List[np.ndarray]]: The train and test sets for C and X as ([C_train, X_train], [C_test, X_test]).
"""
- Predicts context-specific networks.
+ return super()._split_train_data(C, X, Y_required=False, **kwargs)
+
+    def predict_networks(
+        self,
+        C: np.ndarray,
+        with_offsets: bool = False,
+        individual_preds: bool = False,
+        **kwargs,
+    ) -> Union[
+        np.ndarray,
+        List[np.ndarray],
+        Tuple[np.ndarray, np.ndarray],
+        Tuple[List[np.ndarray], List[np.ndarray]],
+    ]:
+        """Predicts context-specific networks given contextual features.
+
+        Args:
+            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+            with_offsets (bool, optional): If True, returns both the network parameters and offsets. Defaults to False.
+            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
+
+        Returns:
+            Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The predicted network parameters (and offsets if with_offsets is True). Returned as lists of individual bootstraps if individual_preds is True.
+        """
-        betas, mus = self.predict_params(C, uses_y=False, **kwargs)
+        betas, mus = self.predict_params(C, individual_preds=individual_preds, uses_y=False, **kwargs)
         if with_offsets:
             return betas, mus
         return betas
- def predict_X(self, C, X, **kwargs):
- """
- Predicts X based on context-specific networks.
+ def predict_X(
+ self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Reconstructs the data matrix based on predicted contextualized networks and the true data matrix.
+ Useful for measuring reconstruction error or for imputation.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ X (np.ndarray): The data matrix (n_samples, n_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
+ **kwargs: Keyword arguments for the Lightning trainer's predict_y method.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or matrices for each bootstrap if individual_preds is True (n_samples, n_features).
"""
- return self.predict(C, X, **kwargs)
+ return self.predict(C, X, individual_preds=individual_preds, **kwargs)
class ContextualizedCorrelationNetworks(ContextualizedNetworks):
"""
- Easy interface to Contextualized Correlation Networks.
+ Contextualized Correlation Networks reveal context-varying feature correlations, interaction strengths, and dependencies in feature groups.
+ Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
@@ -49,18 +103,25 @@ def __init__(self, **kwargs):
ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs
)
- def predict_correlation(self, C, individual_preds=True, squared=True, **kwargs):
- """
- Predict correlation matrices.
+ def predict_correlation(
+ self, C: np.ndarray, individual_preds: bool = True, squared: bool = True
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Predicts context-specific correlations between features.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True.
+ squared (bool, optional): If True, returns the squared correlations. Defaults to True.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The predicted context-specific correlation matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
"""
get_dataloader = lambda i: self.models[i].dataloader(
C, np.zeros((len(C), self.x_dim))
)
rhos = np.array(
[
- self.trainers[i].predict_params(
- self.models[i], get_dataloader(i), **kwargs
- )[0]
+ self.trainers[i].predict_params(self.models[i], get_dataloader(i))[0]
for i in range(len(self.models))
]
)
@@ -73,9 +134,18 @@ def predict_correlation(self, C, individual_preds=True, squared=True, **kwargs):
return np.square(np.mean(rhos, axis=0))
return np.mean(rhos)
- def measure_mses(self, C, X, individual_preds=False):
- """
- Measure mean-squared errors.
+ def measure_mses(
+ self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Measures mean-squared errors.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ X (np.ndarray): The data matrix (n_samples, n_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
"""
betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples
@@ -92,15 +162,35 @@ def measure_mses(self, C, X, individual_preds=False):
class ContextualizedMarkovNetworks(ContextualizedNetworks):
"""
- Easy interface to Contextualized Markov Networks.
+ Contextualized Markov Networks reveal context-varying feature dependencies, cliques, and modules.
+ Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as Markov Networks.
+ Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
super().__init__(ContextualizedMarkovGraph, [], [], MarkovTrainer, **kwargs)
- def predict_precisions(self, C, individual_preds=True):
- """
- Predict precision matrices.
+ def predict_precisions(
+ self, C: np.ndarray, individual_preds: bool = True
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Predicts context-specific precision matrices.
+ Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1.
+ Can be converted to context-specific covariance matrices by taking the inverse.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The predicted context-specific Markov networks as precision matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
"""
get_dataloader = lambda i: self.models[i].dataloader(
C, np.zeros((len(C), self.x_dim))
@@ -115,9 +205,18 @@ def predict_precisions(self, C, individual_preds=True):
return precisions
return np.mean(precisions, axis=0)
- def measure_mses(self, C, X, individual_preds=False):
- """
- Measure mean-squared errors.
+ def measure_mses(
+ self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Measures mean-squared errors.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ X (np.ndarray): The data matrix (n_samples, n_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
"""
betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples
@@ -140,18 +239,40 @@ def measure_mses(self, C, X, individual_preds=False):
class ContextualizedBayesianNetworks(ContextualizedNetworks):
"""
- Easy interface to Contextualized Bayesian Networks.
- Uses NOTMAD model.
- See this paper:
- https://arxiv.org/abs/2111.01104
- for more details.
+ Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal context-dependent causal relationships, effect sizes, and variable ordering.
+ Uses the NOTMAD model, see the `paper <https://arxiv.org/abs/2111.01104>`__ for detailed estimation procedures.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 16. Always uses archetypes in the ContextualizedMetaModel.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ archetype_dag_loss_type (str, optional): The type of loss to use for the archetype loss. Defaults to "l1".
+ archetype_l1 (float, optional): The strength of the l1 regularization for the archetype loss. Defaults to 0.0.
+ archetype_dag_params (dict, optional): Parameters for the archetype loss. Defaults to {"loss_type": "l1", "params": {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}}.
+ archetype_dag_loss_params (dict, optional): Parameters for the archetype loss. Defaults to {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}.
+ archetype_alpha (float, optional): The strength of the alpha regularization for the archetype loss. Defaults to 0.0.
+ archetype_rho (float, optional): The strength of the rho regularization for the archetype loss. Defaults to 0.0.
+ archetype_s (float, optional): The strength of the s regularization for the archetype loss. Defaults to 0.0.
+ archetype_tol (float, optional): The tolerance for the archetype loss. Defaults to 1e-4.
+ archetype_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the archetype loss. Defaults to False.
+ init_mat (np.ndarray, optional): The initial adjacency matrix for the archetype loss. Defaults to None.
+ num_factors (int, optional): The number of factors for the archetype loss. Defaults to 0.
+ factor_mat_l1 (float, optional): The strength of the l1 regularization for the factor matrix for the archetype loss. Defaults to 0.
+ sample_specific_dag_loss_type (str, optional): The type of loss to use for the sample-specific loss. Defaults to "l1".
+ sample_specific_alpha (float, optional): The strength of the alpha regularization for the sample-specific loss. Defaults to 0.0.
+ sample_specific_rho (float, optional): The strength of the rho regularization for the sample-specific loss. Defaults to 0.0.
+ sample_specific_s (float, optional): The strength of the s regularization for the sample-specific loss. Defaults to 0.0.
+ sample_specific_tol (float, optional): The tolerance for the sample-specific loss. Defaults to 1e-4.
+ sample_specific_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the sample-specific loss. Defaults to False.
"""
def _parse_private_init_kwargs(self, **kwargs):
"""
- Parses private init kwargs.
- """
+ Parses the kwargs for the NOTMAD model.
+ Args:
+ **kwargs: Keyword arguments for the NOTMAD model, including the encoder, archetype loss, sample-specific loss, and optimization parameters.
+ """
# Encoder Parameters
self._init_kwargs["model"]["encoder_kwargs"] = {
"type": kwargs.pop(
@@ -163,9 +284,11 @@ def _parse_private_init_kwargs(self, **kwargs):
"link_fn": self.constructor_kwargs["encoder_kwargs"]["link_fn"],
},
}
-
+
# Archetype-specific parameters
- archetype_dag_loss_type = kwargs.pop("archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE)
+ archetype_dag_loss_type = kwargs.pop(
+ "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE
+ )
self._init_kwargs["model"]["archetype_loss_params"] = {
"l1": kwargs.get("archetype_l1", 0.0),
"dag": kwargs.get(
@@ -185,9 +308,11 @@ def _parse_private_init_kwargs(self, **kwargs):
}
if self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] <= 0:
- print("WARNING: num_archetypes is 0. NOTMAD requires archetypes. Setting num_archetypes to 16.")
+ print(
+ "WARNING: num_archetypes is 0. NOTMAD requires archetypes. Setting num_archetypes to 16."
+ )
self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] = 16
-
+
# Possibly update values with convenience parameters
for param, value in self._init_kwargs["model"]["archetype_loss_params"]["dag"][
"params"
@@ -213,11 +338,11 @@ def _parse_private_init_kwargs(self, **kwargs):
},
),
}
-
+
# Possibly update values with convenience parameters
- for param, value in self._init_kwargs["model"]["sample_specific_loss_params"]["dag"][
- "params"
- ].items():
+ for param, value in self._init_kwargs["model"]["sample_specific_loss_params"][
+ "dag"
+ ]["params"].items():
self._init_kwargs["model"]["sample_specific_loss_params"]["dag"]["params"][
param
] = kwargs.pop(f"sample_specific_{param}", value)
@@ -227,7 +352,7 @@ def _parse_private_init_kwargs(self, **kwargs):
"learning_rate": kwargs.pop("learning_rate", 1e-3),
"step": kwargs.pop("step", 50),
}
-
+
return [
"archetype_dag_loss_type",
"archetype_l1",
@@ -271,39 +396,56 @@ def __init__(self, **kwargs):
**kwargs,
)
- def predict_params(self, C, **kwargs):
- """
+ def predict_params(
+ self, C: np.ndarray, **kwargs
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Predicts context-specific Bayesian network parameters as linear coefficients in a linear structural equation model (SEM).
- :param C:
- :param individual_preds: (Default value = False)
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
"""
- # Returns betas
- # TODO: No mus for NOTMAD at present.
- return super().predict_params(
- C, model_includes_mus=False, **kwargs
- )
+ # No mus for NOTMAD at present.
+ return super().predict_params(C, model_includes_mus=False, **kwargs)
- def predict_networks(self, C, **kwargs):
- """
- Predicts context-specific networks.
+ def predict_networks(
+ self, C: np.ndarray, project_to_dag: bool = True, **kwargs
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Predicts context-specific Bayesian networks.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by trimming edges until acyclicity is satisfied. Defaults to True.
+ **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
"""
if kwargs.pop("with_offsets", False):
print("No offsets can be returned by NOTMAD.")
betas = self.predict_params(
- C,
- uses_y=False,
- project_to_dag=kwargs.pop("project_to_dag", True),
- **kwargs
+ C, uses_y=False, project_to_dag=project_to_dag, **kwargs
)
-
return betas
- def measure_mses(self, C, X, individual_preds=False):
- """
- Measure mean-squared errors.
+ def measure_mses(
+ self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Measures mean-squared errors.
+
+ Args:
+ C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
+ X (np.ndarray): The data matrix (n_samples, n_features)
+ individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
+ **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
"""
- betas = self.predict_networks(C, individual_preds=True)
+ betas = self.predict_networks(C, individual_preds=True, **kwargs)
mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples
for bootstrap in range(len(betas)):
X_pred = dag_pred_np(X, betas[bootstrap])
diff --git a/contextualized/easy/ContextualizedRegressor.py b/contextualized/easy/ContextualizedRegressor.py
index 117097e..8f7fcae 100644
--- a/contextualized/easy/ContextualizedRegressor.py
+++ b/contextualized/easy/ContextualizedRegressor.py
@@ -14,7 +14,17 @@
class ContextualizedRegressor(SKLearnWrapper):
"""
- sklearn-like interface to Contextualized Regression.
+ Contextualized Linear Regression quantifies context-varying linear relationships.
+
+ Args:
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
+ link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"].
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
+ l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
"""
def __init__(self, **kwargs):
diff --git a/contextualized/easy/tests/test_bayesian_networks.py b/contextualized/easy/tests/test_bayesian_networks.py
index 60ff6eb..5ba4c46 100644
--- a/contextualized/easy/tests/test_bayesian_networks.py
+++ b/contextualized/easy/tests/test_bayesian_networks.py
@@ -28,13 +28,17 @@ def setUp(self):
def test_bayesian_factors(self):
"""Test case for ContextualizedBayesianNetworks."""
- model = ContextualizedBayesianNetworks(encoder_type="ngam", num_archetypes=16, num_factors=2)
+ model = ContextualizedBayesianNetworks(
+ encoder_type="ngam", num_archetypes=16, num_factors=2
+ )
model.fit(self.C, self.X, max_epochs=10)
networks = model.predict_networks(self.C, individual_preds=False)
assert np.shape(networks) == (self.n_samples, self.x_dim, self.x_dim)
networks = model.predict_networks(self.C, factors=True)
assert np.shape(networks) == (self.n_samples, 2, 2)
- model = ContextualizedBayesianNetworks(encoder_type="ngam", num_archetypes=16, num_factors=2)
+ model = ContextualizedBayesianNetworks(
+ encoder_type="ngam", num_archetypes=16, num_factors=2
+ )
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
def test_bayesian_default(self):
@@ -43,7 +47,9 @@ def test_bayesian_default(self):
def test_bayesian_val_split(self):
model = ContextualizedBayesianNetworks()
- self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5)
+ self._quicktest(
+ model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
+ )
def test_bayesian_archetypes(self):
model = ContextualizedBayesianNetworks(num_archetypes=16)
@@ -61,12 +67,16 @@ def test_bayesian_encoder(self):
assert np.shape(networks) == (self.n_samples, self.x_dim, self.x_dim)
def test_bayesian_acyclicity(self):
- model = ContextualizedBayesianNetworks(archetype_dag_loss_type="DAGMA", num_archetypes=16)
+ model = ContextualizedBayesianNetworks(
+ archetype_dag_loss_type="DAGMA", num_archetypes=16
+ )
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
networks = model.predict_networks(self.C, individual_preds=False)
assert np.shape(networks) == (self.n_samples, self.x_dim, self.x_dim)
- model = ContextualizedBayesianNetworks(archetype_dag_loss_type="poly", num_archetypes=16)
+ model = ContextualizedBayesianNetworks(
+ archetype_dag_loss_type="poly", num_archetypes=16
+ )
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
networks = model.predict_networks(self.C, individual_preds=False)
assert np.shape(networks) == (self.n_samples, self.x_dim, self.x_dim)
diff --git a/contextualized/easy/tests/test_correlation_networks.py b/contextualized/easy/tests/test_correlation_networks.py
index 90cd890..52dbd3d 100644
--- a/contextualized/easy/tests/test_correlation_networks.py
+++ b/contextualized/easy/tests/test_correlation_networks.py
@@ -33,7 +33,9 @@ def test_correlation(self):
model = ContextualizedCorrelationNetworks()
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
- self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5)
+ self._quicktest(
+ model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
+ )
model = ContextualizedCorrelationNetworks(num_archetypes=16)
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
diff --git a/contextualized/easy/tests/test_markov_networks.py b/contextualized/easy/tests/test_markov_networks.py
index eb8f1b4..b778321 100644
--- a/contextualized/easy/tests/test_markov_networks.py
+++ b/contextualized/easy/tests/test_markov_networks.py
@@ -30,7 +30,9 @@ def test_markov(self):
"""Test Case for ContextualizedMarkovNetworks."""
model = ContextualizedMarkovNetworks()
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
- self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5)
+ self._quicktest(
+ model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
+ )
model = ContextualizedMarkovNetworks(num_archetypes=16)
self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
diff --git a/contextualized/easy/tests/test_regressor.py b/contextualized/easy/tests/test_regressor.py
index e3f66e2..a0b70a3 100644
--- a/contextualized/easy/tests/test_regressor.py
+++ b/contextualized/easy/tests/test_regressor.py
@@ -90,7 +90,7 @@ def test_regressor(self):
learning_rate=1e-3,
es_patience=float("inf"),
)
-
+
# Check smaller Y.
model = ContextualizedRegressor(
num_archetypes=4, alpha=1e-1, l1_ratio=0.5, mu_ratio=0.1
diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py
index fed4090..33de005 100644
--- a/contextualized/easy/wrappers/SKLearnWrapper.py
+++ b/contextualized/easy/wrappers/SKLearnWrapper.py
@@ -3,6 +3,8 @@
"""
import copy
import os
+from typing import *
+
import numpy as np
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -28,6 +30,19 @@
class SKLearnWrapper:
"""
An sklearn-like wrapper for Contextualized models.
+
+ Args:
+ base_constructor (class): The base class to construct the model.
+ extra_model_kwargs (dict): Extra kwargs to pass to the model constructor.
+ extra_data_kwargs (dict): Extra kwargs to pass to the dataloader constructor.
+ trainer_constructor (class): The trainer class to use.
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
+ loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
+ link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"].
+ alpha (float, optional): Regularization strength. Defaults to 0.0.
+ mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets.
+ l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms.
"""
def _set_defaults(self):
@@ -44,12 +59,12 @@ def _set_defaults(self):
self.default_encoder_type = DEFAULT_ENCODER_TYPE
def __init__(
- self,
- base_constructor,
- extra_model_kwargs,
- extra_data_kwargs,
- trainer_constructor,
- **kwargs,
+ self,
+ base_constructor,
+ extra_model_kwargs,
+ extra_data_kwargs,
+ trainer_constructor,
+ **kwargs,
):
self._set_defaults()
self.base_constructor = base_constructor
@@ -69,7 +84,7 @@ def __init__(
"test_batch_size",
"C_val",
"X_val",
- "val_split"
+ "val_split",
],
"model": [
"loss_fn",
@@ -137,14 +152,18 @@ def __init__(
if k not in self.constructor_kwargs and k not in self.convenience_kwargs
}
# Some args will not be ignored by wrapper because sub-class will handle them.
- #self.private_kwargs = kwargs.pop("private_kwargs", [])
- #self.private_kwargs.append("private_kwargs")
+ # self.private_kwargs = kwargs.pop("private_kwargs", [])
+ # self.private_kwargs.append("private_kwargs")
# Add Predictor-Specific kwargs for parsing.
- self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs(**self.not_constructor_kwargs)
+ self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs(
+ **self.not_constructor_kwargs
+ )
for key, value in self.constructor_kwargs.items():
self._init_kwargs["model"][key] = value
recognized_private_init_kwargs = self._parse_private_init_kwargs(**kwargs)
- for kwarg in set(unrecognized_general_kwargs) - set(recognized_private_init_kwargs):
+ for kwarg in set(unrecognized_general_kwargs) - set(
+ recognized_private_init_kwargs
+ ):
print(f"Received unknown keyword argument {kwarg}, probably ignoring.")
def _organize_and_expand_fit_kwargs(self, **kwargs):
@@ -175,8 +194,8 @@ def maybe_add_kwarg(category, kwarg, default_val):
maybe_add_kwarg("model", "x_dim", self.x_dim)
maybe_add_kwarg("model", "y_dim", self.y_dim)
if (
- "num_archetypes" in organized_kwargs["model"]
- and organized_kwargs["model"]["num_archetypes"] == 0
+ "num_archetypes" in organized_kwargs["model"]
+ and organized_kwargs["model"]["num_archetypes"] == 0
):
del organized_kwargs["model"]["num_archetypes"]
@@ -212,7 +231,6 @@ def maybe_add_kwarg(category, kwarg, default_val):
maybe_add_kwarg("trainer", "accelerator", self.accelerator)
return organized_kwargs
-
def _parse_private_fit_kwargs(self, **kwargs):
"""
Parse private (model-specific) kwargs passed to fit function.
@@ -234,8 +252,9 @@ def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True):
If acceptable=False, the new kwargs will be removed from the list of acceptable kwargs.
"""
if acceptable:
- self.acceptable_kwargs[category] = list(set(
- self.acceptable_kwargs[category]).union(set(new_kwargs)))
+ self.acceptable_kwargs[category] = list(
+ set(self.acceptable_kwargs[category]).union(set(new_kwargs))
+ )
else:
self.acceptable_kwargs[category] = list(
set(self.acceptable_kwargs[category]) - set(new_kwargs)
@@ -252,7 +271,7 @@ def _organize_kwargs(self, **kwargs):
organized_kwargs = {category: {} for category in self.acceptable_kwargs}
unrecognized_kwargs = []
for kwarg, value in kwargs.items():
- #if kwarg in self.private_kwargs:
+ # if kwarg in self.private_kwargs:
# continue
not_found = True
for category, category_kwargs in self.acceptable_kwargs.items():
@@ -367,11 +386,18 @@ def _build_dataloaders(self, model, train_data, val_data, **kwargs):
return train_dataloader, val_dataloader
- def predict(self, C, X, individual_preds=False, **kwargs):
- """
- :param C:
- :param X:
- :param individual_preds: (Default value = False)
+ def predict(
+ self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ """Predict outcomes from context C and predictors X.
+
+ Args:
+ C (np.ndarray): Context array of shape (n_samples, n_context_features)
+ X (np.ndarray): Predictor array of shape (n_samples, n_features)
+ individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray]]: The outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True.
"""
if not hasattr(self, "models") or self.models is None:
raise ValueError(
@@ -392,11 +418,33 @@ def predict(self, C, X, individual_preds=False, **kwargs):
return np.mean(predictions, axis=0)
def predict_params(
- self, C, individual_preds=False, model_includes_mus=True, **kwargs
- ):
+ self,
+ C: np.ndarray,
+ individual_preds: bool = False,
+ model_includes_mus: bool = True,
+ **kwargs,
+ ) -> Union[
+ np.ndarray,
+ List[np.ndarray],
+ Tuple[np.ndarray, np.ndarray],
+ Tuple[List[np.ndarray], List[np.ndarray]],
+ ]:
"""
- :param C:
- :param individual_preds: (Default value = False)
+ Predict context-specific model parameters from context C.
+
+ Args:
+ C (np.ndarray): Context array of shape (n_samples, n_context_features)
+ individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. Defaults to False, averaging across bootstraps.
+ model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). Defaults to True.
+
+ Returns:
+ Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The parameters of the predicted context-specific models.
+ Returned as lists of individual bootstraps if individual_preds is True, otherwise averages the bootstraps for a better estimate.
+ If model_includes_mus is True, returns both coefficients and offsets as a tuple of (betas, mus). Otherwise, returns coefficients (betas) only.
+ For model_includes_mus=True, ([betas], [mus]) if individual_preds is True, otherwise (betas, mus).
+ For model_includes_mus=False, [betas] if individual_preds is True, otherwise betas.
+ betas is shape (n_samples, x_dim, y_dim) or (n_samples, x_dim) if y_dim = 1.
+ mus is shape (n_samples, y_dim) or (n_samples,) if y_dim = 1.
"""
# Returns betas, mus
if kwargs.pop("uses_y", True):
@@ -423,13 +471,25 @@ def predict_params(
return np.mean(betas, axis=0)
return betas
- def fit(self, *args, **kwargs):
+ def fit(self, *args, **kwargs) -> None:
"""
- Fit model to data.
- Requires numpy arrays C, X, with optional Y.
- If target Y is not given, then X is assumed to be the target.
- :param *args: C, X, Y (optional)
- :param **kwargs:
+ Fit contextualized model to data.
+
+ Args:
+ C (np.ndarray): Context array of shape (n_samples, n_context_features)
+ X (np.ndarray): Predictor array of shape (n_samples, n_features)
+ Y (np.ndarray, optional): Target array of shape (n_samples, n_targets). Defaults to None, where X will be used as targets such as in Contextualized Networks.
+ max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1.
+ learning_rate (float, optional): Learning rate for optimizer. Defaults to 1e-3.
+ val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2.
+ n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
+ train_batch_size (int, optional): Batch size for training. Defaults to 1.
+ val_batch_size (int, optional): Batch size for validation. Defaults to 16.
+ test_batch_size (int, optional): Batch size for testing. Defaults to 16.
+ es_patience (int, optional): Number of epochs to wait before early stopping. Defaults to 1.
+ es_monitor (str, optional): Metric to monitor for early stopping. Defaults to "val_loss".
+ es_mode (str, optional): Mode for early stopping. Defaults to "min".
+ es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False.
"""
self.models = []
self.trainers = []
@@ -469,7 +529,9 @@ def fit(self, *args, **kwargs):
for f in organized_kwargs["trainer"]["callback_constructors"]
]
del my_trainer_kwargs["callback_constructors"]
- trainer = self.trainer_constructor(**my_trainer_kwargs, enable_progress_bar=False)
+ trainer = self.trainer_constructor(
+ **my_trainer_kwargs, enable_progress_bar=False
+ )
checkpoint_callback = my_trainer_kwargs["callbacks"][1]
os.makedirs(checkpoint_callback.dirpath, exist_ok=True)
try:
diff --git a/contextualized/modules.py b/contextualized/modules.py
index fc69678..a880f26 100644
--- a/contextualized/modules.py
+++ b/contextualized/modules.py
@@ -143,11 +143,7 @@ class Linear(nn.Module):
Linear encoder
"""
- def __init__(
- self,
- input_dim,
- output_dim
- ):
+ def __init__(self, input_dim, output_dim):
super().__init__()
self.linear = MLP(
input_dim, output_dim, width=output_dim, layers=0, activation=None
@@ -158,11 +154,7 @@ def forward(self, X):
return self.linear(X)
-ENCODERS = {
- "mlp": MLP,
- "ngam": NGAM,
- "linear": Linear
-}
+ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear}
if __name__ == "__main__":
diff --git a/contextualized_logo.png b/contextualized_logo.png
deleted file mode 100644
index 5acce57..0000000
Binary files a/contextualized_logo.png and /dev/null differ
diff --git a/dev_requirements.txt b/dev_requirements.txt
new file mode 100644
index 0000000..7071e1b
--- /dev/null
+++ b/dev_requirements.txt
@@ -0,0 +1,22 @@
+# This file specifies extra dependencies for the development of Contextualized ML
+
+# Style
+black==23.12.1
+pylint==2.15.5
+pylint-badge @ git+https://github.com/blengerich/pylint-badge
+
+# Documentation
+jupyter-book==0.15.1
+myst-parser==0.18.1
+Sphinx==5.0.2
+sphinx-book-theme==1.0.1
+sphinx-jupyterbook-latex==0.5.2
+sphinx-rtd-theme==2.0.0
+sphinx_external_toc==0.3.1
+
+# Testing
+pytest==7.4.3
+
+# Packaging
+toml==0.10.2
+tomli==2.0.1
diff --git a/docs/_config.yml b/docs/_config.yml
index 2b7a4f1..631c484 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -12,10 +12,11 @@ url: "https://contextualized.ml"
# Force re-execution of notebooks on each build.
# See https://jupyterbook.org/content/execute.html
execute:
- execute_notebooks: force
- timeout: 300
+ execute_notebooks: 'off'
+ # execute_notebooks: force
+ # timeout: 300
-only_build_toc_files: true
+only_build_toc_files: false
# Define the name of the latex output file for PDF builds
latex:
@@ -37,3 +38,27 @@ repository:
html:
use_issues_button: true
use_repository_button: true
+
+# https://jupyterbook.org/en/stable/advanced/developers.html
+sphinx:
+ extra_extensions:
+ - 'sphinx.ext.autodoc'
+ - 'sphinx.ext.napoleon'
+ - 'sphinx.ext.viewcode'
+ - 'sphinx.ext.autosummary'
+ config:
+ add_module_names: False
+ autosummary_generate: True
+ html_theme: sphinx_book_theme
+ # templates_path: ['_templates']
+ # - sphinx.ext.duration
+# - sphinx.ext.doctest
+# - sphinx.ext.intersphinx
+# - nbsphinx
+# - myst_parser
+ # 'sphinx.ext.doctest',
+ # 'sphinx.ext.autodoc',
+ # 'sphinx.ext.autosummary',
+ # 'sphinx.ext.intersphinx',
+ # 'nbsphinx',
+ # 'myst_parser',
\ No newline at end of file
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 16307a3..87accbe 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -25,6 +25,8 @@ parts:
- caption: Demos
chapters:
- file: demos/custom_models
- - file: demos/robust-outliers
- file: demos/benefits
-
+ - caption: API Reference
+ chapters:
+ - file: source/easy
+ - file: source/analysis
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..687060f
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,37 @@
+###############################################################################
+# Auto-generated by `jupyter-book config`
+# If you wish to continue using _config.yml, make edits to that file and
+# re-generate this one.
+###############################################################################
+add_module_names = False
+author = 'Contextualized.ML Team'
+autosummary_generate = True
+bibtex_bibfiles = ['references.bib']
+comments_config = {'hypothesis': False, 'utterances': False}
+copyright = '2023'
+exclude_patterns = ['**.ipynb_checkpoints', '.DS_Store', 'Thumbs.db', '_build']
+extensions = ['sphinx_togglebutton', 'sphinx_copybutton', 'myst_nb', 'jupyter_book', 'sphinx_thebe', 'sphinx_comments', 'sphinx_external_toc', 'sphinx.ext.intersphinx', 'sphinx_design', 'sphinx_book_theme', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', 'sphinxcontrib.bibtex', 'sphinx_jupyterbook_latex']
+external_toc_exclude_missing = False
+external_toc_path = '_toc.yml'
+html_baseurl = ''
+html_favicon = ''
+html_logo = 'logo.png'
+html_sourcelink_suffix = ''
+html_theme = 'sphinx_book_theme'
+html_theme_options = {'search_bar_text': 'Search this book...', 'launch_buttons': {'notebook_interface': 'classic', 'binderhub_url': '', 'jupyterhub_url': '', 'thebe': False, 'colab_url': ''}, 'path_to_docs': '', 'repository_url': 'https://github.com/cnellington/contextualized', 'repository_branch': 'master', 'extra_footer': '', 'home_page_in_toc': True, 'announcement': '', 'analytics': {'google_analytics_id': ''}, 'use_repository_button': True, 'use_edit_page_button': False, 'use_issues_button': True}
+html_title = 'Contextualized.ML Documentation'
+latex_engine = 'pdflatex'
+myst_enable_extensions = ['colon_fence', 'dollarmath', 'linkify', 'substitution', 'tasklist']
+myst_url_schemes = ['mailto', 'http', 'https']
+nb_execution_allow_errors = False
+nb_execution_cache_path = ''
+nb_execution_excludepatterns = []
+nb_execution_in_temp = False
+nb_execution_mode = 'off'
+nb_execution_timeout = 30
+nb_output_stderr = 'show'
+numfig = True
+pygments_style = 'sphinx'
+suppress_warnings = ['myst.domains']
+use_jupyterbook_latex = True
+use_multitoc_numbering = True
diff --git a/docs/models/easy_bayesian_networks_factors.ipynb b/docs/models/easy_bayesian_networks_factors.ipynb
index 54880e2..9cd0277 100644
--- a/docs/models/easy_bayesian_networks_factors.ipynb
+++ b/docs/models/easy_bayesian_networks_factors.ipynb
@@ -5,7 +5,7 @@
"id": "6e32bc2f",
"metadata": {},
"source": [
- "# Contextualized Bayesian Networks"
+ "# Low-dimensional Contextualized Bayesian Networks"
]
},
{
@@ -259,7 +259,7 @@
"outputs": [
{
"data": {
- "image/png": "\n",
+ "image/png": "",
"text/plain": [
"