Skip to content

Commit

Permalink
add correlation.py with marchenko_pastur plot
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 4, 2021
1 parent 494ac24 commit 2e6665c
Show file tree
Hide file tree
Showing 14 changed files with 1,255 additions and 4 deletions.
1 change: 1 addition & 0 deletions assets/marchenko_pastur.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/marchenko_pastur_rank_deficient.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion assets/ptable_elemental_prevalence.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion assets/ptable_elemental_prevalence_log.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion assets/ptable_elemental_ratio.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion assets/ptable_elemental_ratio_log.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
600 changes: 600 additions & 0 deletions data/rand_tall_matrix.csv

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions data/rand_wide_matrix.csv

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlmatrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .correlation import marchenko_pastur, marchenko_pastur_pdf
from .cumulative import add_dropdown, cum_err, cum_res
from .elements import (
count_elements,
Expand Down
92 changes: 92 additions & 0 deletions mlmatrics/correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from numpy.typing import ArrayLike as Array


def marchenko_pastur_pdf(x: float, gamma: float, sigma: float = 1) -> float:
"""The Marchenko-Pastur probability density function describes the
distribution of singular values of large rectangular random matrices.
See https://wikipedia.org/wiki/Marchenko-Pastur_distribution.
By comparing the eigenvalue distribution of a correlation matrix to this
PDF, one can gauge the significance of correlations.
Args:
x (float): Position at which to compute probability density.
gamma (float): Also referred to as lambda. The distribution's main parameter
that measures how well sampled the data is.
sigma (float, optional): Standard deviation of random variables assumed
to be independent identically distributed. Defaults to 1 as
appropriate for correlation matrices.
Returns:
float: Marchenko-Pastur density for given gamma at x
"""
lambda_m = (sigma * (1 - np.sqrt(1 / gamma))) ** 2 # Largest eigenvalue
lambda_p = (sigma * (1 + np.sqrt(1 / gamma))) ** 2 # Smallest eigenvalue

prefac = gamma / (2 * np.pi * sigma ** 2 * x)
root = np.sqrt((lambda_p - x) * (x - lambda_m))
unit_step = x > lambda_p or x < lambda_m

return prefac * root * (0 if unit_step else 1)


def marchenko_pastur(
matrix: Array,
gamma: float,
sigma: float = 1,
filter_high_evals: bool = False,
ax: Axes = None,
) -> None:
"""Plot the eigenvalue distribution of a symmetric matrix (usually a correlation
matrix) against the Marchenko Pastur distribution.
The probability of a random matrix having eigenvalues larger than (1 + sqrt(gamma))^2
in the absence of any signal is vanishingly small. Thus, if eigenvalues larger than
that appear, they correspond to statistically significant signals.
Args:
matrix (Array): 2d array
gamma (float): The Marchenko-Pastur ratio of random variables to observation
count. E.g. for N=1000 variables and p=500 observations of each,
gamma = p/N = 1/2.
sigma (float, optional): Standard deviation of random variables. Defaults to 1.
filter_high_evals (bool, optional): Whether to filter out eigenvalues larger
than the theoretical random maximum. Useful for focusing the plot on the area
of the MP PDF. Defaults to False.
ax (Axes, optional): plt axes. Defaults to None.
"""
if ax is None:
ax = plt.gca()

# use eigh for speed since correlation matrix is symmetric
evals, _ = np.linalg.eigh(matrix)

lambda_m = (sigma * (1 - np.sqrt(1 / gamma))) ** 2 # Largest eigenvalue
lambda_p = (sigma * (1 + np.sqrt(1 / gamma))) ** 2 # Smallest eigenvalue

if filter_high_evals:
# Remove eigenvalues larger than those expected in a purely random matrix
evals = evals[evals <= lambda_p + 1]

ax.hist(evals, bins=50, edgecolor="black", density=True)

# Plot the theoretical density
mp_pdf = np.vectorize(lambda x: marchenko_pastur_pdf(x, gamma, sigma))
x = np.linspace(max(1e-4, lambda_m), lambda_p, 200)
ax.plot(x, mp_pdf(x), linewidth=5)

# Compute and display matrix rank
# A ratio less than one indicates an undersampled set of RVs
rank = np.linalg.matrix_rank(matrix)
n_rows = matrix.shape[0]

plt.text(
*[0.95, 0.9],
f"rank deficiency: {rank}/{n_rows} {'(None)' if n_rows == rank else ''}",
transform=ax.transAxes,
ha="right",
)
8 changes: 8 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ See [`mlmatrics/relevance.py`](mlmatrics/relevance.py).
| :-------------------------------------------------------: | :--------------------------------------------------------------------: |
| ![roc_curve](assets/roc_curve.svg) | ![precision_recall_curve](assets/precision_recall_curve.svg) |

## Correlation

See [`mlmatrics/correlation.py`](mlmatrics/correlation.py).

| [`marchenko_pastur(corr_mat, gamma=ncols/nrows)`](mlmatrics/correlation.py) | [`marchenko_pastur(corr_mat_rank_deficient, gamma=ncols/nrows)`](mlmatrics/correlation.py) |
| :-------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------: |
| ![marchenko_pastur](assets/marchenko_pastur.svg) | ![marchenko_pastur_rank_deficient](assets/marchenko_pastur_rank_deficient.svg) |

## Histograms

See [`mlmatrics/histograms.py`](mlmatrics/histograms.py).
Expand Down
12 changes: 12 additions & 0 deletions scripts/gen_rand_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,15 @@
{"y_binary": y_binary, "y_proba": y_proba, "y_pred": y_proba.round()}
)
df_clf.to_csv(f"{ROOT}/data/rand_clf.csv", index=False, float_format="%g")

r_rows, n_cols = 500, 1000
rand_mat = np.random.normal(0, 1, size=(r_rows, n_cols))
pd.DataFrame(rand_mat).to_csv(
f"{ROOT}/data/rand_wide_matrix.csv", index=False, float_format="%g", header=False
)

r_rows, n_cols = 600, 500
rand_mat = np.random.normal(0, 1, size=(r_rows, n_cols))
pd.DataFrame(rand_mat).to_csv(
f"{ROOT}/data/rand_tall_matrix.csv", index=False, float_format="%g", header=False
)
15 changes: 15 additions & 0 deletions scripts/plot_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
density_scatter_with_hist,
err_decay,
hist_elemental_prevalence,
marchenko_pastur,
precision_recall_curve,
ptable_elemental_prevalence,
ptable_elemental_ratio,
Expand Down Expand Up @@ -156,3 +157,17 @@ def savefig(filename: str) -> None:
# %% Histogram Plots
residual_hist(y_true, y_pred)
savefig("residual_hist")


# %% Correlation Plots
rand_wide_mat = pd.read_csv(f"{ROOT}/data/rand_wide_matrix.csv", header=None).to_numpy()
r_rows, n_cols = rand_wide_mat.shape
corr_mat = np.corrcoef(rand_wide_mat)
marchenko_pastur(corr_mat, gamma=n_cols / r_rows)
savefig("marchenko_pastur")

rand_tall_mat = pd.read_csv(f"{ROOT}/data/rand_tall_matrix.csv", header=None).to_numpy()
r_rows, n_cols = rand_tall_mat.shape
corr_mat_rank_deficient = np.corrcoef(rand_tall_mat)
marchenko_pastur(corr_mat_rank_deficient, gamma=n_cols / r_rows)
savefig("marchenko_pastur_rank_deficient")
21 changes: 21 additions & 0 deletions tests/test_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np

from mlmatrics import marchenko_pastur


def test_marchenko_pastur():

r_rows, n_cols = 50, 100
rand_mat = np.random.normal(0, 1, size=(r_rows, n_cols))
corr_mat = np.corrcoef(rand_mat)

marchenko_pastur(corr_mat, gamma=n_cols / r_rows)


def test_marchenko_pastur_filter_high_evals():

r_rows, n_cols = 50, 100
rand_mat = np.random.normal(0, 1, size=(r_rows, n_cols))
corr_mat = np.corrcoef(rand_mat)

marchenko_pastur(corr_mat, gamma=n_cols / r_rows, filter_high_evals=True)

0 comments on commit 2e6665c

Please sign in to comment.