Skip to content

Commit

Permalink
doc string improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Sep 3, 2021
1 parent 4cf00a0 commit 6572d85
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 56 deletions.
2 changes: 1 addition & 1 deletion ml_matrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
from .quantile import qq_gaussian
from .ranking import err_decay
from .relevance import precision_recall_curve, roc_curve
from .utils import ROOT, annotate_bar_heights
from .utils import ROOT, add_mae_r2_box, annotate_bar_heights
8 changes: 4 additions & 4 deletions ml_matrics/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def marchenko_pastur(
"""Plot the eigenvalue distribution of a symmetric matrix (usually a correlation
matrix) against the Marchenko Pastur distribution.
The probability of a random matrix having eigenvalues larger than (1 + sqrt(gamma))^2
in the absence of any signal is vanishingly small. Thus, if eigenvalues larger than
that appear, they correspond to statistically significant signals.
The probability of a random matrix having eigenvalues >= (1 + sqrt(gamma))^2 in the
absence of any signal is vanishingly small. Thus, if eigenvalues larger than that
appear, they correspond to statistically significant signals.
Args:
matrix (Array): 2d array
Expand All @@ -56,7 +56,7 @@ def marchenko_pastur(
gamma = p/N = 1/2.
sigma (float, optional): Standard deviation of random variables. Defaults to 1.
filter_high_evals (bool, optional): Whether to filter out eigenvalues larger
than the theoretical random maximum. Useful for focusing the plot on the area
than theoretical random maximum. Useful for focusing the plot on the area
of the MP PDF. Defaults to False.
ax (Axes, optional): plt.Axes object. Defaults to None.
"""
Expand Down
39 changes: 23 additions & 16 deletions ml_matrics/elements.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Sequence, Union, cast
from typing import Any, Dict, Literal, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -12,13 +12,18 @@
from ml_matrics.utils import ROOT, annotate_bar_heights


PTABLE = pd.read_csv(f"{ROOT}/ml_matrics/elements.csv")


def count_elements(
formulas: Sequence[str] = None, elem_counts: pd.Series = None
formulas: Sequence[str] = None, elem_counts: Union[pd.Series, Dict[str, int]] = None
) -> pd.Series:
"""Count occurrences of each chemical element in a materials dataset.
Args:
formulas (list[str]): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
formulas (list[str], optional): compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
elem_counts (pd.Series | dict[str, int], optional): map from element symbol to
prevalence counts
Returns:
pd.Series: Total number of appearances of each element in `formulas`.
Expand All @@ -29,8 +34,8 @@ def count_elements(
):
raise ValueError("provide either formulas or elem_counts, not neither nor both")

# elem_counts is sure to be a Series at this point but mypy needs help to realize
elem_counts = cast(pd.Series, elem_counts)
# ensure elem_counts is Series if we got a dict
elem_counts = pd.Series(elem_counts)

if formulas is None:
return elem_counts
Expand All @@ -43,9 +48,8 @@ def count_elements(

# ensure all elements are present in returned Series (with count zero if they
# weren't in formulas)
ptable = pd.read_csv(f"{ROOT}/ml_matrics/elements.csv")
# fill_value=0 required as max(NaN, any int) = NaN
srs = srs.combine(pd.Series(0, index=ptable.symbol), max, fill_value=0)
srs = srs.combine(pd.Series(0, index=PTABLE.symbol), max, fill_value=0)
return srs


Expand Down Expand Up @@ -168,11 +172,14 @@ def ptable_elemental_ratio(
compositions.
Args:
formulas_a (list[str]): numerator compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
formulas_b (list[str]): denominator compositional strings, e.g. ["Fe2O3", "Bi2Te3"]
elem_counts_a (pd.Series): Map from element symbol to prevalence count for numerator
elem_counts_b (pd.Series): Map from element symbol to prevalence count for denominator
kwargs (dict, optional): kwargs passed to ptable_elemental_prevalence
formulas_a (list[str], optional): numerator compositional strings, e.g.
["Fe2O3", "Bi2Te3"]
formulas_b (list[str], optional): denominator compositional strings
elem_counts_a (pd.Series | dict[str, int], optional): map from element symbol
to prevalence count for numerator
elem_counts_b (pd.Series | dict[str, int], optional): map from element symbol
to prevalence count for denominator
kwargs (Any, optional): kwargs passed to ptable_elemental_prevalence
"""

elem_counts_a = count_elements(formulas_a, elem_counts_a)
Expand Down Expand Up @@ -204,7 +211,7 @@ def hist_elemental_prevalence(
log: bool = False,
keep_top: int = None,
ax: Axes = None,
bar_values: str = "percent",
bar_values: Literal["percent", "count", None] = "percent",
**kwargs: Any,
) -> None:
"""Plots a histogram of the prevalence of each element in a materials dataset.
Expand All @@ -216,9 +223,9 @@ def hist_elemental_prevalence(
log (bool, optional): Whether y-axis is log or linear. Defaults to False.
keep_top (int | None): Display only the top n elements by prevalence.
ax (Axes): plt.Axes object. Defaults to None.
bar_values (str): One of 'percent', 'count' or None. Annotate bars with the
percentage each element makes up in the total element count, or use the count
itself, or display no bar labels.
bar_values ('percent'|'count'|None): 'percent' annotates bars with the
percentage each element makes up in the total element count. 'count'
displays count itself. None removes bar labels.
**kwargs (int): Keyword arguments passed to annotate_bar_heights.
"""
if ax is None:
Expand Down
10 changes: 6 additions & 4 deletions ml_matrics/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ def true_pred_hist(
truth_color: str = "blue",
**kwargs: Any,
) -> Axes:
"""Plot a histogram of model predictions with bars colored by the average uncertainty of
predictions in that bin. Overlayed by a more transparent histogram of ground truth values.
"""Plot a histogram of model predictions with bars colored by the mean uncertainty of
predictions in that bin. Overlaid by a more transparent histogram of ground truth
values.
Args:
y_true (array): ground truth targets
y_pred (array): model predictions
y_std (array): model uncertainty
ax (Axes, optional): plt.Axes object. Defaults to None.
cmap (str, optional): string identifier of a plt colormap. Defaults to "hot".
cmap (str, optional): string identifier of a plt colormap. Defaults to 'hot'.
bins (int, optional): Histogram resolution. Defaults to 50.
log (bool, optional): Whether to log-scale the y-axis. Defaults to True.
truth_color (str, optional): Face color to use for y_true bars. Defaults to "blue".
truth_color (str, optional): Face color to use for y_true bars.
Defaults to 'blue'.
Returns:
Axes: plt.Axes object with plotted data.
Expand Down
25 changes: 6 additions & 19 deletions ml_matrics/parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.offsetbox import AnchoredText
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.interpolate import interpn
from sklearn.metrics import r2_score

from ml_matrics.utils import NumArray, with_hist
from ml_matrics.utils import NumArray, add_mae_r2_box, with_hist


def hist_density(
Expand All @@ -26,7 +24,7 @@ def hist_density(
bins (int, optional): Number of bins (histogram resolution). Defaults to 100.
Returns:
tuple[NumArray]: x- and y-coordinates (sorted by density) as well as density itself.
tuple[array]: x- and y-coordinates (sorted by density) as well as density itself
"""

data, x_e, y_e = np.histogram2d(xs, ys, bins=bins)
Expand All @@ -47,18 +45,6 @@ def hist_density(
return xs, ys, zs


def add_mae_r2_box(
    xs: NumArray, ys: NumArray, ax: Axes, loc: str = "lower right"
) -> None:
    """Annotate an axes with the MAE and R^2 score between two arrays.

    The mean absolute error is computed as mean(|xs - ys|) and the coefficient
    of determination as r2_score(xs, ys). Both are formatted to 3 decimal
    places and placed on `ax` as a frameless anchored text box.

    Args:
        xs (array): x values, passed as the first argument to r2_score.
        ys (array): y values, passed as the second argument to r2_score.
        ax (Axes): plt.Axes object to add the text box to.
        loc (str, optional): Anchor location of the text box.
            Defaults to "lower right".
    """
    mae = np.abs(xs - ys).mean()
    r2 = r2_score(xs, ys)

    # two lines of mathtext: MAE on top, R^2 below
    label = f"$\\mathrm{{MAE}} = {mae:.3f}$\n" + f"$R^2 = {r2:.3f}$"

    ax.add_artist(AnchoredText(label, loc=loc, frameon=False))


def density_scatter(
xs: NumArray,
ys: NumArray,
Expand All @@ -79,11 +65,12 @@ def density_scatter(
xs (array): x values.
ys (array): y values.
ax (Axes, optional): plt.Axes object. Defaults to None.
color_map (str, optional): plt color map or valid string name. Defaults to "Blues".
color_map (str, optional): plt color map or valid string name.
Defaults to "Blues".
sort (bool, optional): Whether to sort the data. Defaults to True.
log (bool, optional): Whether to log the color scale. Defaults to True.
density_bins (int, optional): How many density_bins to use for the density histogram,
i.e. granularity of the density color scale. Defaults to 100.
density_bins (int, optional): How many density_bins to use for the density
histogram, i.e. granularity of the density color scale. Defaults to 100.
xlabel (str, optional): x-axis label. Defaults to "Actual".
ylabel (str, optional): y-axis label. Defaults to "Predicted".
identity (bool, optional): Whether to add an identity/parity line (y = x).
Expand Down
2 changes: 1 addition & 1 deletion ml_matrics/quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def qq_gaussian(
The measure of calibration is how well the uncertainty percentiles conform
to those of a normal distribution.
Inspired by https://github.com/uncertainty-toolbox/uncertainty-toolbox#visualizations.
Inspired by https://git.io/JufOz.
Info on Q-Q plots: https://wikipedia.org/wiki/Q-Q_plot
Args:
Expand Down
29 changes: 26 additions & 3 deletions ml_matrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from os.path import abspath, dirname
from typing import Sequence, Union
from typing import Any, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.offsetbox import AnchoredText
from numpy.typing import NDArray
from sklearn.metrics import r2_score


ROOT: str = dirname(dirname(abspath(__file__)))
Expand All @@ -15,7 +17,10 @@


def with_hist(
xs: NumArray, ys: NumArray, cell: GridSpec = None, bins: int = 100 # type: ignore
xs: NDArray[np.float64],
ys: NDArray[np.float64],
cell: GridSpec = None,
bins: int = 100,
) -> Axes:
"""Call before creating a plot and use the returned `ax_main` for all
subsequent plotting ops to create a grid of plots with the main plot in
Expand Down Expand Up @@ -53,7 +58,7 @@ def with_hist(
return ax_main


def softmax(arr: NumArray, axis: int = -1) -> NumArray: # type: ignore
def softmax(arr: NDArray[np.float64], axis: int = -1) -> NDArray[np.float64]:
    """Compute the softmax of an array along an axis.

    Uses the numerically stable formulation: the maximum along `axis` is
    subtracted before exponentiating. This leaves the result mathematically
    unchanged but prevents overflow (and the resulting NaNs) for large inputs.

    Args:
        arr (array): Input values (logits).
        axis (int, optional): Axis along which to normalize. Defaults to -1.

    Returns:
        array: Same shape as `arr`, with entries along `axis` summing to 1.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # nothing to normalize; np.max would raise on a zero-size reduction
        return np.exp(arr)

    # softmax(x) == softmax(x - c) for any constant c along the axis
    shifted = arr - np.max(arr, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)
Expand Down Expand Up @@ -102,3 +107,21 @@ def annotate_bar_heights(
ax.annotate(label, (x_pos, y_pos), ha="center", fontsize=fontsize)
# ensure enough vertical space to display label above highest bar
ax.margins(y=0.1)


def add_mae_r2_box(
    xs: NDArray[np.float64],
    ys: NDArray[np.float64],
    ax: Axes = None,
    loc: str = "lower right",
    **kwargs: Any,
) -> None:
    """Annotate an axes with the MAE and R^2 score between two arrays.

    The mean absolute error is computed as mean(|xs - ys|) and the coefficient
    of determination as r2_score(xs, ys). Both are formatted to 3 decimal
    places and placed on the axes as a frameless anchored text box.

    Args:
        xs (array): x values, passed as the first argument to r2_score.
        ys (array): y values, passed as the second argument to r2_score.
        ax (Axes, optional): plt.Axes object to add the text box to. If None,
            the current axes (plt.gca()) is used. Defaults to None.
        loc (str, optional): Anchor location of the text box.
            Defaults to "lower right".
        **kwargs (Any): Keyword arguments passed to AnchoredText.
    """
    target_ax = plt.gca() if ax is None else ax

    mae = np.abs(xs - ys).mean()
    r2 = r2_score(xs, ys)

    # two lines of mathtext: MAE on top, R^2 below
    label = f"$\\mathrm{{MAE}} = {mae:.3f}$\n" + f"$R^2 = {r2:.3f}$"

    target_ax.add_artist(AnchoredText(label, loc=loc, frameon=False, **kwargs))
8 changes: 4 additions & 4 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ When adding new SVG assets, please compress them before committing. This can eit
This project uses [`pytest`](https://docs.pytest.org/en/stable/usage.html). To run the entire test suite:

```sh
python -m pytest
pytest
```

To run individual or groups of test files, pass `pytest` a path or glob pattern, respectively:

```sh
python -m pytest tests/test_cumulative.py
python -m pytest **/test_*_metrics.py
pytest tests/test_cumulative.py
pytest **/test_*_metrics.py
```

To run a single test, pass its name to the `-k` flag:

```sh
python -m pytest -k test_precision_recall_curve
pytest -k test_precision_recall_curve
```

Consult the [`pytest`](https://docs.pytest.org/en/stable/usage.html) docs for more details.
Expand Down
2 changes: 1 addition & 1 deletion scripts/fetch_mp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


# %%
# requires MP API key in ~/.pmgrc.yml available at https://materialsproject.org/dashboard
# needs MP API key in ~/.pmgrc.yml available at https://materialsproject.org/dashboard
# pmg config --add PMG_MAPI_KEY <your_key>
with MPRester() as mpr:
formulas = mpr.query({"nelements": {"$lt": 2}}, ["pretty_formula"])
Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ url = https://github.com/janosh/ml-matrics
author = Janosh Riebesell
author_email = [email protected]
license = MIT
license_file = license
license_files = license
keywords = machine-learning, materials-discovery, metrics, plots, visualizations, model-performance
classifiers =
Programming Language :: Python :: 3
Expand Down Expand Up @@ -48,8 +48,8 @@ universal = True

# Tooling Config
[flake8]
# tell flake8 to use black's default line length
max-line-length = 95
# use black's default line length
max-line-length = 88
max-complexity = 16
# E731: do not assign a lambda expression, use a def
# E203: whitespace before ':'
Expand Down

0 comments on commit 6572d85

Please sign in to comment.