Skip to content

Commit

Permalink
Add new plotly function sankey_from_2_df_cols() (#37)
Browse files Browse the repository at this point in the history
* ptable_heatmap_plotly() rename kwarg hover_cols -> hover_props

* add plotly function sankey_from_2_df_cols() + tests

    lives in new module pymatviz/sankey.py

* move save_and_compress_svg() to pymatviz/utils.py

    was in examples/generate_assets.py before

* add examples for sankey_from_2_df_cols() to readme and examples/matbench_dielectric_eda.ipynb

* update test_ptable_heatmap_plotly() hover_cols -> hover_props
  • Loading branch information
janosh authored May 21, 2022
1 parent 0c8ff8f commit b73dbfa
Show file tree
Hide file tree
Showing 13 changed files with 65,116 additions and 57 deletions.
File renamed without changes
1 change: 1 addition & 0 deletions assets/sankey_from_2_df_cols_randints.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/sankey_spglib_vs_aflow.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 11 additions & 30 deletions examples/generate_assets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
# %%
from shutil import which
from subprocess import call

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matminer.datasets import load_dataset
from plotly.graph_objs._figure import Figure

from pymatviz.correlation import marchenko_pastur
from pymatviz.cumulative import cum_err, cum_res
Expand All @@ -28,9 +24,10 @@
from pymatviz.quantile import qq_gaussian
from pymatviz.ranking import err_decay
from pymatviz.relevance import precision_recall_curve, roc_curve
from pymatviz.sankey import sankey_from_2_df_cols
from pymatviz.struct_vis import plot_structure_2d
from pymatviz.sunburst import spacegroup_sunburst
from pymatviz.utils import ROOT
from pymatviz.utils import ROOT, save_and_compress_svg


# %%
Expand Down Expand Up @@ -69,29 +66,6 @@
y_std = np.sqrt(y_var_ale + y_var_epi)


def save_and_compress_svg(filename: str, fig: Figure | None = None) -> None:
    """Save a figure as SVG to the assets/ folder and compress it with svgo.

    The svgo CLI is only invoked if it is found in PATH; otherwise the
    uncompressed SVG is left as written.

    Args:
        filename (str): Name of SVG file (w/o extension).
        fig (Figure | None): Plotly Figure or Matplotlib Figure instance. If None,
            the current Matplotlib figure is saved and then closed.
            Defaults to None.

    Raises:
        AssertionError: If filename already ends with ".svg".
        TypeError: If fig is neither None, a Plotly Figure nor a Matplotlib
            Figure.
    """
    assert not filename.endswith(".svg"), f"{filename = } should not include .svg"
    filepath = f"{ROOT}/assets/{filename}.svg"

    if isinstance(fig, Figure):
        # Plotly figure
        fig.write_image(filepath)
    elif fig is None:
        # No figure passed: fall back to Matplotlib's current figure
        plt.savefig(filepath, bbox_inches="tight")
        plt.close()
    elif isinstance(fig, plt.Figure):
        # Explicit Matplotlib figure, as promised by the TypeError message below
        fig.savefig(filepath, bbox_inches="tight")
        plt.close(fig)
    else:
        raise TypeError(f"{fig = } should be a Plotly Figure or Matplotlib Figure")

    # Compression is best-effort: skip silently when svgo is not installed
    if (svgo := which("svgo")) is not None:
        call([svgo, "--multipass", filepath])


# %% Parity Plots
density_scatter(y_pred, y_true)
save_and_compress_svg("density_scatter")
Expand Down Expand Up @@ -146,7 +120,7 @@ def save_and_compress_svg(filename: str, fig: Figure | None = None) -> None:
# %% Plotly interactive periodic table heatmap
fig = ptable_heatmap_plotly(
df_ptable.atomic_mass,
hover_cols=["atomic_mass", "atomic_number"],
hover_props=["atomic_mass", "atomic_number"],
hover_data="density = " + df_ptable.density.astype(str) + " g/cm^3",
)
fig.update_layout(
Expand Down Expand Up @@ -268,4 +242,11 @@ def save_and_compress_svg(filename: str, fig: Figure | None = None) -> None:
ax = plot_structure_2d(struct, ax=ax)
ax.set_title(struct.composition.reduced_formula)

save_and_compress_svg("mp-structures-2d", fig)
save_and_compress_svg("mp_structures_2d", fig)


# %% Sankey diagram of random integers
col_names = "col_a col_b".split()
df = pd.DataFrame(np.random.randint(1, 6, size=(100, 2)), columns=col_names)
fig = sankey_from_2_df_cols(df, col_names, labels_with_counts="percent")
# Underscore-separated name so this regenerates the committed asset
# assets/sankey_from_2_df_cols_randints.svg instead of a hyphenated duplicate
save_and_compress_svg("sankey_from_2_df_cols_randints", fig)
2 changes: 1 addition & 1 deletion examples/matbench_dielectric_eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,7 @@
}
],
"source": [
"fig = spacegroup_sunburst(df_diel.spg_num, show_values=\"percent\")\n",
"fig = spacegroup_sunburst(df_diel.spg_num, show_counts=\"percent\")\n",
"fig.update_layout(title=\"Space group sunburst\")\n",
"fig.write_image(\"dielectric-spacegroup-sunburst.pdf\")\n",
"fig.show()"
Expand Down
64,931 changes: 64,930 additions & 1 deletion examples/matbench_perovskites.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .quantile import qq_gaussian
from .ranking import err_decay
from .relevance import precision_recall_curve, roc_curve
from .sankey import sankey_from_2_df_cols
from .struct_vis import plot_structure_2d
from .sunburst import spacegroup_sunburst
from .utils import ROOT, add_mae_r2_box, annotate_bars
Loading

0 comments on commit b73dbfa

Please sign in to comment.