Skip to content

Commit

Permalink
Self-import refactor (#194)
Browse files Browse the repository at this point in the history
* import pymatviz pkg as namespace

* bump ruff commit hook + auto fixes

* fix "LS fit: y = ..." annotations for facet plots: they all lay on top of each other but should be positioned in the corner of each subplot

* restore df_ptable importable from top-level pymatviz namespace

* import save_and_compress_svg as pmv.io.save_and_compress_svg

* # Please enter the commit message for your changes. Lines starting
# with '#' will be ignored, and an empty message aborts the commit.
#
# On branch import-refactor
# Your branch is up to date with 'origin/import-refactor'.
#
# Changes to be committed:
#	modified:   examples/dataset_exploration/camd_2022/explore_camd_2022.py
#	modified:   examples/dataset_exploration/matbench/dielectric/explore_dielectric.py
#	modified:   examples/dataset_exploration/matbench/log_g+kvrh/explore_log_g+krvh.py
#	modified:   examples/dataset_exploration/matbench/perovskites/explore_perovskites.py
#	modified:   examples/dataset_exploration/matpes/eda.py
#	modified:   examples/dataset_exploration/wbm/explore_wbm.py
#	modified:   examples/diatomics/plot.py
#	modified:   examples/make_assets/histogram.py
#	modified:   examples/make_assets/ptable/ptable_matplotlib.py
#	modified:   examples/make_assets/ptable/ptable_plotly.py
#	modified:   pymatviz/__init__.py
#	modified:   pymatviz/histogram.py
#	modified:   pymatviz/ptable/ptable_plotly.py
#	modified:   pymatviz/scatter.py
#	modified:   pymatviz/sunburst.py
#	modified:   pymatviz/uncertainty.py
#	modified:   tests/ptable/test_ptable_matplotlib.py
#	modified:   tests/ptable/test_ptable_plotly.py
#	modified:   tests/test_io.py
#	modified:   tests/test_process_data.py
#

* fix circular import

* more import refactor

* good enough

* fix test side effect altering pmv.df_ptable
  • Loading branch information
janosh authored Aug 18, 2024
1 parent 8700d3b commit f4f8a30
Show file tree
Hide file tree
Showing 54 changed files with 590 additions and 650 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
rev: v0.6.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -73,7 +73,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.8.0
rev: v9.9.0
hooks:
- id: eslint
types: [file]
Expand All @@ -87,6 +87,6 @@ repos:
- typescript-eslint

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.374
rev: v1.1.376
hooks:
- id: pyright
2 changes: 1 addition & 1 deletion assets/density-scatter-plotly.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import matplotlib.pyplot as plt
from matminer.datasets import load_dataset

import pymatviz as pmv
from pymatviz import count_elements, ptable_heatmap
from pymatviz.enums import Key
from pymatviz.io import save_fig


# %%
Expand All @@ -45,7 +45,7 @@
count_elements(df_boltz[Key.formula]), log=True, return_type="figure"
)
fig.suptitle("Elements in BoltzTraP MP dataset")
save_fig(fig, "boltztrap_mp-ptable-heatmap.pdf")
pmv.save_fig(fig, "boltztrap_mp-ptable-heatmap.pdf")


# %%
Expand All @@ -54,13 +54,13 @@
return_type="figure",
)
fig.suptitle("Elements of top 100 n-type powerfactors in BoltzTraP MP dataset")
save_fig(fig, "boltztrap_mp-ptable-heatmap-top-100-nPF.pdf")
pmv.save_fig(fig, "boltztrap_mp-ptable-heatmap-top-100-nPF.pdf")


# %%
ax = df_boltz.hist(bins=50, log=True, layout=[2, 3], figsize=[18, 8])
plt.suptitle("BoltzTraP MP")
save_fig(ax, "boltztrap_mp-hists.pdf")
pmv.save_fig(ax, "boltztrap_mp-hists.pdf")


# %%
Expand Down
7 changes: 3 additions & 4 deletions examples/dataset_exploration/camd_2022/explore_camd_2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
import requests
from pymatgen.symmetry.groups import SpaceGroup

import pymatviz as pmv
from pymatviz import count_elements, ptable_heatmap, spacegroup_sunburst
from pymatviz.enums import Key
from pymatviz.io import save_fig
from pymatviz.powerups import annotate_bars


# %% Download data (if needed)
Expand All @@ -48,12 +47,12 @@
elem_counts = count_elements(df_camd.reduced_formula)
fig = ptable_heatmap(elem_counts, log=True, return_type="figure")
fig.suptitle("Elements in CAMD 2022 dataset")
save_fig(fig, "camd-2022-ptable-heatmap.pdf")
pmv.save_fig(fig, "camd-2022-ptable-heatmap.pdf")


# %%
ax = df_camd.data_source.value_counts().plot.bar(fontsize=18, rot=0)
annotate_bars(ax, v_offset=3e3)
pmv.powerups.annotate_bars(ax, v_offset=3e3)


# %%
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,8 @@
from matminer.datasets import load_dataset
from tqdm import tqdm

from pymatviz import (
count_elements,
crystal_sys_order,
ptable_heatmap,
ptable_heatmap_plotly,
spacegroup_bar,
spacegroup_sunburst,
)
import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.io import save_fig
from pymatviz.utils import crystal_sys_from_spg_num


"""matbench_dielectric dataset
Expand Down Expand Up @@ -48,37 +39,39 @@
]
df_diel[Key.n_wyckoff] = df_diel.wyckoff.map(count_wyckoff_positions)

df_diel[Key.crystal_system] = df_diel[Key.spg_num].map(crystal_sys_from_spg_num)
df_diel[Key.crystal_system] = df_diel[Key.spg_num].map(
pmv.utils.crystal_sys_from_spg_num
)

df_diel[Key.volume] = [x.volume for x in df_diel[Key.structure]]
df_diel[Key.formula] = [x.formula for x in df_diel[Key.structure]]


# %%
fig = ptable_heatmap(
count_elements(df_diel[Key.formula]), log=True, return_type="figure"
fig = pmv.ptable_heatmap(
pmv.count_elements(df_diel[Key.formula]), log=True, return_type="figure"
)
fig.suptitle("Elemental prevalence in the Matbench dielectric dataset")
save_fig(fig, "dielectric-ptable-heatmap.pdf")
pmv.save_fig(fig, "dielectric-ptable-heatmap.pdf")


# %%
fig = ptable_heatmap_plotly(df_diel[Key.formula], log=True, colorscale="viridis")
fig = pmv.ptable_heatmap_plotly(df_diel[Key.formula], log=True, colorscale="viridis")
title = "<b>Elements in Matbench Dielectric</b>"
fig.layout.title = dict(text=title, x=0.4, y=0.94, font_size=20)
# save_fig(fig, "dielectric-ptable-heatmap-plotly.pdf")
# pmv.save_fig(fig, "dielectric-ptable-heatmap-plotly.pdf")


# %%
ax = spacegroup_bar(df_diel[Key.spg_num])
ax = pmv.spacegroup_bar(df_diel[Key.spg_num])
ax.set_title("Space group histogram", y=1.1)
save_fig(ax, "dielectric-spacegroup-hist.pdf")
pmv.save_fig(ax, "dielectric-spacegroup-hist.pdf")


# %%
fig = spacegroup_sunburst(df_diel[Key.spg_num], show_counts="percent")
fig = pmv.spacegroup_sunburst(df_diel[Key.spg_num], show_counts="percent")
fig.layout.title = "Space group sunburst"
# save_fig(fig, "dielectric-spacegroup-sunburst.pdf")
# pmv.save_fig(fig, "dielectric-spacegroup-sunburst.pdf")
fig.show()


Expand All @@ -95,7 +88,7 @@

x_ticks = {} # custom x axis tick labels
for cry_sys, df_group in sorted(
df_diel.groupby(Key.crystal_system), key=lambda x: crystal_sys_order.index(x[0])
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
):
x_ticks[cry_sys] = (
f"<b>{cry_sys}</b><br>"
Expand All @@ -108,11 +101,11 @@
fig.layout.margin = dict(b=10, l=10, r=10, t=50)
fig.layout.showlegend = False
fig.layout.xaxis = reusable_x_axis = dict(
tickvals=list(range(len(crystal_sys_order))), ticktext=list(x_ticks.values())
tickvals=list(range(len(pmv.crystal_sys_order))), ticktext=list(x_ticks.values())
)


# save_fig(fig, "dielectric-violin.pdf")
# pmv.save_fig(fig, "dielectric-violin.pdf")
fig.show()


Expand All @@ -125,7 +118,7 @@
points="all",
hover_data=[Key.spg_num],
hover_name=Key.formula,
category_orders={Key.crystal_system: crystal_sys_order},
category_orders={Key.crystal_system: pmv.crystal_sys_order},
log_y=True,
).update_traces(jitter=1)

Expand All @@ -137,7 +130,7 @@ def rgb_color(val: float, max_val: float) -> str:

x_ticks = {}
for cry_sys, df_group in sorted(
df_diel.groupby(Key.crystal_system), key=lambda x: crystal_sys_order.index(x[0])
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
):
n_wyckoff = df_group[Key.n_wyckoff].mean()
clr = rgb_color(n_wyckoff, 14)
Expand All @@ -153,7 +146,7 @@ def rgb_color(val: float, max_val: float) -> str:
fig.layout.showlegend = False
fig.layout.update(width=1000, height=400, xaxis=reusable_x_axis)

# save_fig(fig, "dielectric-violin-num-wyckoffs.pdf")
# pmv.save_fig(fig, "dielectric-violin-num-wyckoffs.pdf")
fig.show()


Expand All @@ -175,5 +168,5 @@ def rgb_color(val: float, max_val: float) -> str:
# slightly increase scatter point size (lower sizeref means larger)
fig.update_traces(marker_sizeref=0.08, selector=dict(mode="markers"))

# save_fig(fig, "dielectric-scatter.pdf")
# pmv.save_fig(fig, "dielectric-scatter.pdf")
fig.show()
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from matminer.datasets import load_dataset
from pymatgen.core import Composition

import pymatviz as pmv
from pymatviz import count_elements, ptable_heatmap
from pymatviz.enums import Key
from pymatviz.io import save_fig


"""Stats for the matbench_expt_gap dataset.
Expand Down Expand Up @@ -62,7 +62,7 @@ def mean_atomic_prop(comp: Composition, prop: str) -> float | None:
return_type="figure",
)
fig.suptitle("Elements in Matbench experimental band gap dataset")
save_fig(fig, "expt-gap-ptable-heatmap.pdf")
pmv.save_fig(fig, "expt-gap-ptable-heatmap.pdf")


# %%
Expand Down
24 changes: 12 additions & 12 deletions examples/dataset_exploration/matbench/jdft2d/explore_jdft2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
# %%
from matminer.datasets import load_dataset
from tqdm import tqdm

from pymatviz import count_elements, ptable_heatmap, spacegroup_bar, spacegroup_sunburst
from pymatviz.enums import Key
from pymatviz.io import save_fig


"""Stats for the matbench_jdft2d dataset.
Input: Pymatgen Structure of the material.
Expand All @@ -20,6 +11,15 @@
https://ml.materialsproject.org/projects/matbench_jdft2d
"""

# %%
from matminer.datasets import load_dataset
from tqdm import tqdm

import pymatviz as pmv
from pymatviz import count_elements, ptable_heatmap, spacegroup_bar, spacegroup_sunburst
from pymatviz.enums import Key


# %%
df_2d = load_dataset("matbench_jdft2d")

Expand All @@ -32,7 +32,7 @@

# %%
ax = df_2d.hist(column="exfoliation_en", bins=50, log=True)
save_fig(ax, "jdft2d-exfoliation-energy-hist.pdf")
pmv.save_fig(ax, "jdft2d-exfoliation-energy-hist.pdf")


# %%
Expand All @@ -41,12 +41,12 @@

fig = ptable_heatmap(count_elements(df_2d[Key.formula]), log=True, return_type="figure")
fig.suptitle("Elemental prevalence in the Matbench Jarvis DFT 2D dataset")
save_fig(fig, "jdft2d-ptable-heatmap.pdf")
pmv.save_fig(fig, "jdft2d-ptable-heatmap.pdf")


# %%
ax = spacegroup_bar(df_2d[Key.spg_num], log=True)
save_fig(ax, "jdft2d-spacegroup-hist.pdf")
pmv.save_fig(ax, "jdft2d-spacegroup-hist.pdf")


# %%
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pymatgen.core import Structure
from tqdm import tqdm

import pymatviz as pmv
from pymatviz import (
count_elements,
crystal_sys_order,
Expand All @@ -30,8 +31,6 @@
spacegroup_sunburst,
)
from pymatviz.enums import Key
from pymatviz.io import save_fig
from pymatviz.utils import crystal_sys_from_spg_num


# %%
Expand All @@ -45,9 +44,10 @@
df_grvh[Key.structure], desc="Getting matbench_log_gvrh spacegroups"
)
]
df_grvh[Key.crystal_system] = [
crystal_sys_from_spg_num(x) for x in df_grvh[Key.spg_num]
]
df_grvh[Key.crystal_system] = df_grvh[Key.spg_num].map(
pmv.utils.crystal_sys_from_spg_num
)


df_grvh[Key.wyckoff] = [
get_protostructure_label_from_spglib(struct)
Expand All @@ -73,15 +73,15 @@
ax = df_kvrh.hist(column="log10(K_VRH)", bins=50, alpha=0.8)

df_grvh.hist(column="log10(G_VRH)", bins=50, ax=ax, alpha=0.8)
save_fig(ax, "log_g+kvrh-target-hist.pdf")
pmv.save_fig(ax, "log_g+kvrh-target-hist.pdf")


# %%
df_grvh[Key.volume] = [x.volume for x in df_grvh[Key.structure]]
df_grvh[Key.formula] = [x.formula for x in df_grvh[Key.structure]]

ax = df_grvh.hist(column=Key.volume, bins=50, log=True, alpha=0.8)
save_fig(ax, "log_gvrh-volume-hist.pdf")
pmv.save_fig(ax, "log_gvrh-volume-hist.pdf")


# %%
Expand Down Expand Up @@ -136,12 +136,12 @@ def has_isolated_atom(crystal: Structure, radius: float = 5) -> bool:
count_elements(df_grvh[Key.formula]), log=True, return_type="figure"
)
fig.suptitle("Elemental prevalence in the Matbench bulk/shear modulus datasets")
save_fig(fig, "log_gvrh-ptable-heatmap.pdf")
pmv.save_fig(fig, "log_gvrh-ptable-heatmap.pdf")


# %%
ax = spacegroup_bar(df_grvh[Key.spg_num])
save_fig(ax, "log_gvrh-spacegroup-hist.pdf")
pmv.save_fig(ax, "log_gvrh-spacegroup-hist.pdf")


# %%
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# %%
from matminer.datasets import load_dataset

import pymatviz as pmv
from pymatviz import count_elements, ptable_heatmap
from pymatviz.enums import Key
from pymatviz.io import save_fig


"""Stats for the matbench_mp_e_form dataset.
Expand All @@ -25,7 +25,7 @@

# %%
ax = df_e_form.hist(column="e_form", bins=50, log=True)
save_fig(ax, "mp_e_form_hist.pdf")
pmv.save_fig(ax, "mp_e_form_hist.pdf")


# %%
Expand All @@ -38,4 +38,4 @@
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Matbench formation energy dataset")
save_fig(fig, "mp_e_form-ptable-heatmap.pdf")
pmv.save_fig(fig, "mp_e_form-ptable-heatmap.pdf")
Loading

0 comments on commit f4f8a30

Please sign in to comment.