Skip to content

Commit

Permalink
breaking: rename <xyz>_kwds -> <xyz>_kwargs for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Sep 5, 2024
1 parent 571f251 commit 24261ca
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/make_assets/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# %% Uncertainty Plots
ax = pmv.qq_gaussian(
y_pred, y_true, y_std, identity_line={"line_kwds": {"color": "red"}}
y_pred, y_true, y_std, identity_line={"line_kwargs": {"color": "red"}}
)
pmv.io.save_and_compress_svg(ax, "normal-prob-plot")

Expand Down
24 changes: 13 additions & 11 deletions examples/mlff_phonons.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,39 +52,41 @@
"for directory in (RUNS_DIR,):\n",
" os.makedirs(directory, exist_ok=True)\n",
"\n",
"common_relax_kwds = dict(fmax=0.00001)\n",
"mace_kwds = dict(model=\"medium\")\n",
"chgnet_kwds = dict(optimizer_kwargs=dict(use_device=\"mps\"), assign_magmoms=False)\n",
"common_relax_kwargs = dict(fmax=0.00001)\n",
"mace_kwargs = dict(model=\"medium\")\n",
"chgnet_kwargs = dict(optimizer_kwargs=dict(use_device=\"mps\"), assign_magmoms=False)\n",
"\n",
"do_mlff_relax = True # whether to MLFF-relax the PBE structure\n",
"models = {\n",
" str(Model.mace_mp): dict(\n",
" bulk_relax_maker=ff_jobs.MACERelaxMaker(\n",
" relax_kwargs=common_relax_kwds,\n",
" calculator_kwargs={\"default_dtype\": \"float64\"} | mace_kwds,\n",
" relax_kwargs=common_relax_kwargs,\n",
" calculator_kwargs={\"default_dtype\": \"float64\"} | mace_kwargs,\n",
" )\n",
" if do_mlff_relax\n",
" else None,\n",
" phonon_displacement_maker=ff_jobs.MACEStaticMaker(calculator_kwargs=mace_kwds),\n",
" static_energy_maker=ff_jobs.MACEStaticMaker(calculator_kwargs=mace_kwds),\n",
" phonon_displacement_maker=ff_jobs.MACEStaticMaker(\n",
" calculator_kwargs=mace_kwargs\n",
" ),\n",
" static_energy_maker=ff_jobs.MACEStaticMaker(calculator_kwargs=mace_kwargs),\n",
" ),\n",
" str(Model.m3gnet_ms): dict(\n",
" bulk_relax_maker=ff_jobs.M3GNetRelaxMaker(relax_kwargs=common_relax_kwds)\n",
" bulk_relax_maker=ff_jobs.M3GNetRelaxMaker(relax_kwargs=common_relax_kwargs)\n",
" if do_mlff_relax\n",
" else None,\n",
" phonon_displacement_maker=ff_jobs.M3GNetStaticMaker(),\n",
" static_energy_maker=ff_jobs.M3GNetStaticMaker(),\n",
" ),\n",
" str(Model.chgnet_030): dict(\n",
" bulk_relax_maker=ff_jobs.CHGNetRelaxMaker(\n",
" relax_kwargs=common_relax_kwds, calculator_kwargs=chgnet_kwds\n",
" relax_kwargs=common_relax_kwargs, calculator_kwargs=chgnet_kwargs\n",
" )\n",
" if do_mlff_relax\n",
" else None,\n",
" phonon_displacement_maker=ff_jobs.CHGNetStaticMaker(\n",
" calculator_kwargs=chgnet_kwds\n",
" calculator_kwargs=chgnet_kwargs\n",
" ),\n",
" static_energy_maker=ff_jobs.CHGNetStaticMaker(calculator_kwargs=chgnet_kwds),\n",
" static_energy_maker=ff_jobs.CHGNetStaticMaker(calculator_kwargs=chgnet_kwargs),\n",
" ),\n",
"}"
]
Expand Down
8 changes: 4 additions & 4 deletions pymatviz/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,18 @@ def spacegroup_bar(
for patch in ax.patches[0 if x0 == 1 else x0 : x1 + 1]:
patch.set_facecolor(color)

text_kwds = dict(transform=transform, horizontalalignment="center") | (
text_kwargs = dict(transform=transform, horizontalalignment="center") | (
text_kwargs or {}
)
crys_sys_anno_kwds = dict(
crys_sys_anno_kwargs = dict(
rotation=90, va="top", ha="right", fontdict={"fontsize": 14}
)
ax.text(*[(x0 + x1) / 2, 0.95], crys_sys, **crys_sys_anno_kwds | text_kwds)
ax.text(*[(x0 + x1) / 2, 0.95], crys_sys, **crys_sys_anno_kwargs | text_kwargs)
if show_counts:
ax.text(
*[(x0 + x1) / 2, 1.02],
f"{si_fmt_int(count)} ({count / len(data):.0%})",
**dict(fontdict={"fontsize": 12}) | text_kwds,
**dict(fontdict={"fontsize": 12}) | text_kwargs,
)

ax.fill_between(
Expand Down
24 changes: 12 additions & 12 deletions pymatviz/powerups/both.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def calculate_metrics(xs: ArrayLike, ys: ArrayLike) -> str:
def add_identity_line(
fig: go.Figure | plt.Figure | plt.Axes,
*,
line_kwds: dict[str, Any] | None = None,
line_kwargs: dict[str, Any] | None = None,
trace_idx: int = 0,
retain_xy_limits: bool = False,
**kwargs: Any,
Expand All @@ -146,8 +146,8 @@ def add_identity_line(
Args:
fig (go.Figure | plt.Figure | plt.Axes): plotly/matplotlib figure or axes to
add the identity line to.
line_kwds (dict[str, Any], optional): Keyword arguments for customizing the line
shape will be passed to fig.add_shape(line=line_kwds). Defaults to
line_kwargs (dict[str, Any], optional): Keyword arguments for customizing the
line shape will be passed to fig.add_shape(line=line_kwargs). Defaults to
dict(color="gray", width=1, dash="dash").
trace_idx (int, optional): Index of the trace to use for measuring x/y limits.
Defaults to 0. Unused if kaleido package is installed and the figure's
Expand All @@ -171,7 +171,7 @@ def add_identity_line(
ax = fig if isinstance(fig, plt.Axes) else fig.gca()

line_defaults = dict(alpha=0.5, zorder=0, linestyle="dashed", color="black")
ax.axline((x_min, x_min), (x_max, x_max), **line_defaults | (line_kwds or {}))
ax.axline((x_min, x_min), (x_max, x_max), **line_defaults | (line_kwargs or {}))
return fig

if isinstance(fig, go.Figure):
Expand All @@ -189,7 +189,7 @@ def add_identity_line(
type="line",
**dict(x0=xy_min_min, y0=xy_min_min, x1=xy_max_min, y1=xy_max_min),
layer="below",
line=line_defaults | (line_kwds or {}),
line=line_defaults | (line_kwargs or {}),
**kwargs,
)
if retain_xy_limits:
Expand All @@ -207,7 +207,7 @@ def add_best_fit_line(
xs: ArrayLike = (),
ys: ArrayLike = (),
trace_idx: int | None = None,
line_kwds: dict[str, Any] | None = None,
line_kwargs: dict[str, Any] | None = None,
annotate_params: bool | dict[str, Any] = True,
warn: bool = True,
**kwargs: Any,
Expand All @@ -223,8 +223,8 @@ def add_best_fit_line(
means use the y-values of trace at trace_idx in fig.
trace_idx (int, optional): Index of the trace to use for measuring x/y values
for fitting if xs and ys are not provided. Defaults to 0.
line_kwds (dict[str, Any], optional): Keyword arguments for customizing the line
shape. For plotly, will be passed to fig.add_shape(line=line_kwds).
line_kwargs (dict[str, Any], optional): Keyword arguments for customizing the
line shape. For plotly, will be passed to fig.add_shape(line=line_kwargs).
For matplotlib, will be passed to ax.plot(). Defaults to None.
annotate_params (dict[str, Any], optional): Pass dict to customize
the annotation of the best fit line. Set to False to disable annotation.
Expand Down Expand Up @@ -302,7 +302,7 @@ def add_best_fit_line(
)

defaults = dict(alpha=0.7, linestyle="--", zorder=1)
ax.axline((x0, y0), (x1, y1), **(defaults | (line_kwds or {})) | kwargs)
ax.axline((x0, y0), (x1, y1), **(defaults | (line_kwargs or {})) | kwargs)

return fig

Expand All @@ -324,8 +324,8 @@ def add_best_fit_line(
x0, x1 = x_min, x_max
y0, y1 = slope * x0 + intercept, slope * x1 + intercept

line_kwds = (
(line_kwds or {})
line_kwargs = (
(line_kwargs or {})
| dict(color=line_color, width=2, dash="dash")
| kwargs.pop("line", {})
)
Expand All @@ -339,7 +339,7 @@ def add_best_fit_line(
y1=y1,
xref=xref,
yref=yref,
line=line_kwds,
line=line_kwargs,
**{k: v for k, v in kwargs.items() if k not in invalid_kwargs},
)

Expand Down
4 changes: 2 additions & 2 deletions pymatviz/structure_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,13 @@ class used to plot chemical bonds. Allowed are edgecolor, facecolor, color,
(0.5 * radius) * direction if occupancy < 1 else (0, 0)
)

txt_kwds = dict(
text_kwargs = dict(
ha="center",
va="center",
zorder=zorder,
**(label_kwargs or {}),
)
ax.text(*(xy + text_offset), txt, **txt_kwds)
ax.text(*(xy + text_offset), txt, **text_kwargs)

start += occupancy

Expand Down
26 changes: 14 additions & 12 deletions tests/powerups/test_both_powerups.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def test_add_best_fit_line_invalid_fig() -> None:


def test_add_best_fit_line_custom_line_kwargs(plotly_scatter: go.Figure) -> None:
line_kwds = {"width": 3, "dash": "dot"}
result = pmv.powerups.add_best_fit_line(plotly_scatter, line_kwds=line_kwds)
line_kwargs = {"width": 3, "dash": "dot"}
result = pmv.powerups.add_best_fit_line(plotly_scatter, line_kwargs=line_kwargs)

best_fit_line = result.layout.shapes[-1]
assert best_fit_line.line.width == 2
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_add_best_fit_line_custom_xs_ys(


@pytest.mark.parametrize(
("xaxis_type", "yaxis_type", "trace_idx", "line_kwds", "retain_xy_limits"),
("xaxis_type", "yaxis_type", "trace_idx", "line_kwargs", "retain_xy_limits"),
[
("linear", "log", 0, None, True),
("log", "linear", 1, {"color": "red"}, False),
Expand All @@ -159,7 +159,7 @@ def test_add_identity_line(
xaxis_type: str,
yaxis_type: str,
trace_idx: int,
line_kwds: dict[str, str] | None,
line_kwargs: dict[str, str] | None,
retain_xy_limits: bool,
) -> None:
# Set axis types
Expand All @@ -173,7 +173,7 @@ def test_add_identity_line(

fig = pmv.powerups.add_identity_line(
plotly_scatter,
line_kwds=line_kwds,
line_kwargs=line_kwargs,
trace_idx=trace_idx,
retain_xy_limits=retain_xy_limits,
)
Expand All @@ -184,7 +184,7 @@ def test_add_identity_line(
assert line is not None

assert line.layer == "below"
assert line.line.color == (line_kwds["color"] if line_kwds else "gray")
assert line.line.color == (line_kwargs["color"] if line_kwargs else "gray")
# check line coordinates
assert line.x0 == line.y0
assert line.x1 == line.y1
Expand All @@ -205,17 +205,19 @@ def test_add_identity_line(
assert y_range_post != y_range_pre


@pytest.mark.parametrize("line_kwds", [None, {"color": "blue"}])
@pytest.mark.parametrize("line_kwargs", [None, {"color": "blue"}])
def test_add_identity_matplotlib(
matplotlib_scatter: plt.Figure, line_kwds: dict[str, str] | None
matplotlib_scatter: plt.Figure, line_kwargs: dict[str, str] | None
) -> None:
expected_line_color = (line_kwds or {}).get("color", "black")
expected_line_color = (line_kwargs or {}).get("color", "black")
# test Figure
fig = pmv.powerups.add_identity_line(matplotlib_scatter, line_kwds=line_kwds)
fig = pmv.powerups.add_identity_line(matplotlib_scatter, line_kwargs=line_kwargs)
assert isinstance(fig, plt.Figure)

# test Axes
ax = pmv.powerups.add_identity_line(matplotlib_scatter.axes[0], line_kwds=line_kwds)
ax = pmv.powerups.add_identity_line(
matplotlib_scatter.axes[0], line_kwargs=line_kwargs
)
assert isinstance(ax, plt.Axes)

line = fig.axes[0].lines[-1] # retrieve identity line
Expand All @@ -225,7 +227,7 @@ def test_add_identity_matplotlib(
_fig_log, ax_log = plt.subplots()
ax_log.plot([1, 10, 100], [10, 100, 1000])
ax_log.set(xscale="log", yscale="log")
ax_log = pmv.powerups.add_identity_line(ax, line_kwds=line_kwds)
ax_log = pmv.powerups.add_identity_line(ax, line_kwargs=line_kwargs)

line = fig.axes[0].lines[-1]
assert line.get_color() == expected_line_color
Expand Down

0 comments on commit 24261ca

Please sign in to comment.