Skip to content

Commit

Permalink
fully type annotate all functions and fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jul 3, 2021
1 parent 168382d commit b5729e3
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 80 deletions.
67 changes: 30 additions & 37 deletions ml_matrics/elements.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Any, Sequence, Union, cast

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -12,7 +12,9 @@
from ml_matrics.utils import ROOT, annotate_bar_heights


def count_elements(formulas: list) -> pd.Series:
def count_elements(
formulas: Sequence[str] = None, elem_counts: pd.Series = None
) -> pd.Series:
"""Count occurrences of each chemical element in a materials dataset.
Args:
Expand All @@ -21,6 +23,18 @@ def count_elements(formulas: list) -> pd.Series:
Returns:
pd.Series: Total number of appearances of each element in `formulas`.
"""

if (formulas is None and elem_counts is None) or (
formulas is not None and elem_counts is not None
):
raise ValueError("provide either formulas or elem_counts, not neither nor both")

# elem_counts is sure to be a Series at this point but mypy needs help to realize
elem_counts = cast(pd.Series, elem_counts)

if formulas is None:
return elem_counts

formula2dict = lambda str: pd.Series(
Composition(str).fractional_composition.as_dict()
)
Expand All @@ -36,7 +50,7 @@ def count_elements(formulas: list) -> pd.Series:


def ptable_elemental_prevalence(
formulas: List[str] = None,
formulas: Sequence[str] = None,
elem_counts: pd.Series = None,
log: bool = False,
ax: Axes = None,
Expand All @@ -63,16 +77,11 @@ def ptable_elemental_prevalence(
Raises:
ValueError: provide either formulas or elem_counts, not neither nor both
"""
if (formulas is None and elem_counts is None) or (
formulas is not None and elem_counts is not None
):
raise ValueError("provide either formulas or elem_counts, not neither nor both")

if formulas is not None:
elem_counts = count_elements(formulas)
elem_counts = count_elements(formulas, elem_counts)

ptable = pd.read_csv(f"{ROOT}/ml_matrics/elements.csv")
cmap = get_cmap(cmap)
color_map = get_cmap(cmap)

n_rows = ptable.row.max()
n_columns = ptable.column.max()
Expand All @@ -87,6 +96,7 @@ def ptable_elemental_prevalence(

norm = LogNorm() if log else Normalize()

# replace positive and negative infinities with NaN values, then drop all NaNs
clean_scale = elem_counts.replace([np.inf, -np.inf], np.nan).dropna()

if cbar_max is not None:
Expand All @@ -111,7 +121,7 @@ def ptable_elemental_prevalence(
color = "white" # not in either formulas_a nor formulas_b
count_label = "0/0"
else:
color = cmap(norm(count)) if count > 0 else "silver"
color = color_map(norm(count)) if count > 0 else "silver"
# replace shortens scientific notation 1e+01 to 1e1 so it fits inside cells
count_label = f"{count:.2g}".replace("e+0", "e")

Expand Down Expand Up @@ -148,11 +158,11 @@ def ptable_elemental_prevalence(


def ptable_elemental_ratio(
formulas_a: List[str] = None,
formulas_b: List[str] = None,
formulas_a: Sequence[str] = None,
formulas_b: Sequence[str] = None,
elem_counts_a: pd.Series = None,
elem_counts_b: pd.Series = None,
**kwargs,
**kwargs: Any,
) -> None:
"""Display the ratio of the normalised prevalence of each element for two sets of
compositions.
Expand All @@ -165,32 +175,15 @@ def ptable_elemental_ratio(
kwargs (dict, optional): kwargs passed to ptable_elemental_prevalence
"""

if (formulas_a is None and elem_counts_a is None) or (
formulas_a is not None and elem_counts_a is not None
):
raise ValueError(
"provide either formulas_a or elem_counts_a, not neither nor both"
)

if (formulas_b is None and elem_counts_b is None) or (
formulas_b is not None and elem_counts_b is not None
):
raise ValueError(
"provide either formulas_b or elem_counts_b, not neither nor both"
)
elem_counts_a = count_elements(formulas_a, elem_counts_a)

if formulas_a is not None:
elem_counts_a = count_elements(formulas_a)
elem_counts_b = count_elements(formulas_b, elem_counts_b)

if formulas_b is not None:
elem_counts_b = count_elements(formulas_b)
elem_counts = elem_counts_a / elem_counts_b

# normalize elemental distributions, just a scaling factor but
# makes different ratio plots comparable
elem_counts_a /= elem_counts_a.sum()
elem_counts_b /= elem_counts_b.sum()

elem_counts = elem_counts_a / elem_counts_b
elem_counts /= elem_counts.sum()

ptable_elemental_prevalence(
elem_counts=elem_counts, cbar_title="Element Ratio", **kwargs
Expand All @@ -207,12 +200,12 @@ def ptable_elemental_ratio(


def hist_elemental_prevalence(
formulas: list,
formulas: Sequence[str],
log: bool = False,
keep_top: int = None,
ax: Axes = None,
bar_values: str = "percent",
**kwargs,
**kwargs: Any,
) -> None:
"""Plots a histogram of the prevalence of each element in a materials dataset.
Expand Down
40 changes: 21 additions & 19 deletions ml_matrics/histograms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -9,7 +11,7 @@


def residual_hist(
y_true: Array, y_pred: Array, ax: Axes = None, xlabel: str = None, **kwargs
y_true: Array, y_pred: Array, ax: Axes = None, xlabel: str = None, **kwargs: Any
) -> Axes:
"""Plot the residual distribution overlayed with a Gaussian kernel
density estimate.
Expand Down Expand Up @@ -55,7 +57,7 @@ def true_pred_hist(
bins: int = 50,
log: bool = True,
truth_color: str = "blue",
**kwargs,
**kwargs: Any,
) -> Axes:
"""Plot a histogram of model predictions with bars colored by the average uncertainty of
predictions in that bin. Overlayed by a more transparent histogram of ground truth values.
Expand All @@ -77,36 +79,36 @@ def true_pred_hist(
if ax is None:
ax = plt.gca()

cmap = getattr(plt.cm, cmap)
color_map = getattr(plt.cm, cmap)
y_true, y_pred, y_std = np.array([y_true, y_pred, y_std])

_, bins, bars = ax.hist(
_, bin_edges, bars = ax.hist(
y_pred, bins=bins, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs
)
ax.figure.set
ax.hist(
y_true,
bins=bins,
bins=bin_edges,
alpha=0.2,
color=truth_color,
label=r"$y_\mathrm{true}$",
**kwargs,
)

for xmin, xmax, rect in zip(bins, bins[1:], bars.patches):
for xmin, xmax, rect in zip(bin_edges, bin_edges[1:], bars.patches):

y_preds_in_rect = np.logical_and(y_pred > xmin, y_pred < xmax).nonzero()

color_value = y_std[y_preds_in_rect].mean()

rect.set_color(cmap(color_value))
rect.set_color(color_map(color_value))

if log:
plt.yscale("log")
ax.legend(frameon=False)

norm = plt.cm.colors.Normalize(vmax=y_std.max(), vmin=y_std.min())
cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), pad=0.075)
cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), pad=0.075)
cbar.outline.set_linewidth(1)
cbar.set_label(r"mean $y_\mathrm{std}$ of prediction in bin")
cbar.ax.yaxis.set_ticks_position("left")
Expand All @@ -116,7 +118,7 @@ def true_pred_hist(
return ax


def spacegroup_hist(spacegroups: Array, ax: Axes = None, **kwargs) -> Axes:
def spacegroup_hist(spacegroups: Array, ax: Axes = None, **kwargs: Any) -> Axes:
"""Plot a histogram of spacegroups shaded by crystal system.
(triclinic, monoclinic, orthorhombic, tetragonal, trigonal, hexagonal, cubic)
Expand All @@ -140,21 +142,21 @@ def spacegroup_hist(spacegroups: Array, ax: Axes = None, **kwargs) -> Axes:
trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)

# https://git.io/JYJcs
crystal_systems = {
"tri-/monoclinic": ["red", (1, 15)],
"orthorhombic": ["blue", (16, 74)],
"tetragonal": ["green", (75, 142)],
"trigonal": ["orange", (143, 167)],
"hexagonal": ["purple", (168, 194)],
"cubic": ["yellow", (195, 230)],
crystal_systems: Dict[str, Tuple[str, Tuple[int, int]]] = {
"tri-/monoclinic": ("red", (1, 15)),
"orthorhombic": ("blue", (16, 74)),
"tetragonal": ("green", (75, 142)),
"trigonal": ("orange", (143, 167)),
"hexagonal": ("purple", (168, 194)),
"cubic": ("yellow", (195, 230)),
}

for name, [color, rng] in crystal_systems.items():
x0, x1 = rng
for name, [color, x_lim] in crystal_systems.items():
x0, x1 = x_lim
for patch in ax.patches[0 if x0 == 1 else x0 : x1 + 1]:
patch.set_facecolor(color)
ax.text(
sum(rng) / 2,
sum(x_lim) / 2,
0.95,
name,
rotation=90,
Expand Down
4 changes: 2 additions & 2 deletions ml_matrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def regression_metrics(
y_true: Array, y_preds: Array, verbose: bool = False
) -> Dict[str, Union[float, dict]]:
) -> Dict[str, Union[float, Dict[str, float]]]:
"""Print a common selection of regression metrics
TODO make robust by finding the common axis
Expand Down Expand Up @@ -98,7 +98,7 @@ def regression_metrics(

def classification_metrics(
target: Array, logits: Array, average: str = "micro", verbose: bool = False
) -> Dict[str, Union[float, dict]]:
) -> Dict[str, Union[float, Dict[str, float]]]:
"""print out metrics for a classification task
TODO make less janky, first index is for ensembles, second data, third classes.
Expand Down
29 changes: 20 additions & 9 deletions ml_matrics/parity.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Tuple
from typing import Any, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.offsetbox import AnchoredText
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from numpy import ndarray as Array
Expand All @@ -15,7 +16,7 @@

def hist_density(
xs: Array, ys: Array, sort: bool = True, bins: int = 100
) -> Tuple[Array]:
) -> Tuple[Array, Array, Array]:
"""Return an approximate density of 2d points.
Args:
Expand Down Expand Up @@ -69,7 +70,7 @@ def density_scatter(
ylabel: str = "Predicted",
identity: bool = True,
stats: bool = True,
**kwargs,
**kwargs: Any,
) -> Axes:
"""Scatter plot colored (and optionally sorted) by density.
Expand Down Expand Up @@ -119,7 +120,7 @@ def scatter_with_err_bar(
xlabel: str = "Actual",
ylabel: str = "Predicted",
title: str = None,
**kwargs,
**kwargs: Any,
) -> Axes:
"""Scatter plot with optional x- and/or y-error bars. Useful when passing model
uncertainties as yerr=y_std for checking if uncertainty correlates with error,
Expand Down Expand Up @@ -158,7 +159,7 @@ def density_hexbin(
color_map: Array = None,
xlabel: str = "Actual",
ylabel: str = "Predicted",
):
) -> Axes:
"""Hexagonal-grid scatter plot colored by density or by third dimension
passed color_by"""
if ax is None:
Expand All @@ -175,23 +176,33 @@ def density_hexbin(

ax.set(xlabel=xlabel, ylabel=ylabel)

return ax


def density_scatter_with_hist(xs, ys, cell=None, bins=100, **kwargs):
def density_scatter_with_hist(
xs: Array, ys: Array, cell: GridSpec = None, bins: int = 100, **kwargs: Any
) -> Axes:
"""Scatter plot colored (and optionally sorted) by density
with histograms along each dimension
"""

ax_scatter = with_hist(xs, ys, cell, bins)
density_scatter(xs, ys, ax_scatter, **kwargs)
ax = density_scatter(xs, ys, ax_scatter, **kwargs)

return ax

def density_hexbin_with_hist(xs, ys, cell=None, bins=100, **kwargs):

def density_hexbin_with_hist(
xs: Array, ys: Array, cell: GridSpec = None, bins: int = 100, **kwargs: Any
) -> Axes:
"""Hexagonal-grid scatter plot colored by density or by third dimension
passed color_by with histograms along each dimension.
"""

ax_scatter = with_hist(xs, ys, cell, bins)
density_hexbin(xs, ys, ax_scatter, **kwargs)
ax = density_hexbin(xs, ys, ax_scatter, **kwargs)

return ax


def residual_vs_actual(y_true: Array, y_pred: Array, ax: Axes = None) -> Axes:
Expand Down
8 changes: 5 additions & 3 deletions ml_matrics/quantile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Dict, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -8,7 +8,9 @@
from ml_matrics.utils import add_identity


def qq_gaussian(y_true: Array, y_pred: Array, y_std: Union[Array, dict]) -> None:
def qq_gaussian(
y_true: Array, y_pred: Array, y_std: Union[Array, Dict[str, Array]]
) -> None:
"""Plot the Gaussian quantile-quantile (Q-Q) plot of one (passed as array)
or multiple (passed as dict) sets of uncertainty estimates for a single
pair of ground truth targets `y_true` and model predictions `y_pred`.
Expand All @@ -28,7 +30,7 @@ def qq_gaussian(y_true: Array, y_pred: Array, y_std: Union[Array, dict]) -> None
y_pred (Array): model predictions
y_std (Array | dict): model uncertainties
"""
if type(y_std) != dict:
if isinstance(y_std, Array):
y_std = {"std": y_std}

res = np.abs(y_pred - y_true)
Expand Down
Loading

0 comments on commit b5729e3

Please sign in to comment.