Skip to content

Commit

Permalink
Merge pull request #586 from py-econometrics/mpl_patch
Browse files Browse the repository at this point in the history
overhaul mpl backend; fix coordinate flip argument
  • Loading branch information
apoorvalal authored Aug 26, 2024
2 parents ed5553d + 43ae9ab commit b66f517
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 39 deletions.
127 changes: 89 additions & 38 deletions pyfixest/report/visualize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn.objects as so
from lets_plot import (
LetsPlot,
aes,
Expand Down Expand Up @@ -73,6 +73,7 @@ def iplot(
exact_match: bool = False,
plot_backend: str = "lets_plot",
labels: Optional[dict] = None,
ax: Optional[plt.Axes] = None,
joint: Optional[Union[str, bool]] = None,
seed: Optional[int] = None,
):
Expand Down Expand Up @@ -199,6 +200,7 @@ def iplot(
title=title,
flip_coord=coord_flip,
labels=labels,
ax=ax,
)


Expand All @@ -218,6 +220,7 @@ def coefplot(
labels: Optional[dict] = None,
joint: Optional[Union[str, bool]] = None,
seed: Optional[int] = None,
ax: Optional[plt.Axes] = None,
):
r"""
Plot model coefficients with confidence intervals.
Expand Down Expand Up @@ -329,6 +332,7 @@ def coefplot(
title=title,
flip_coord=coord_flip,
labels=labels,
ax=ax,
)


Expand All @@ -353,6 +357,7 @@ def _coefplot_lets_plot(
title: Optional[str] = None,
flip_coord: Optional[bool] = True,
labels: Optional[dict] = None,
ax=None, # for compatibility with matplotlib backend
):
"""
Plot model coefficients with confidence intervals.
Expand Down Expand Up @@ -430,17 +435,18 @@ def _coefplot_matplotlib(
title: Optional[str] = None,
flip_coord: Optional[bool] = True,
labels: Optional[dict] = None,
ax: Optional[plt.Axes] = None,
dodge: float = 0.5,
**fig_kwargs,
) -> so.Plot:
) -> plt.Figure:
"""
Plot model coefficients with confidence intervals.
We use the seaborn library to create the plot through the seaborn objects interface.
Plot model coefficients with confidence intervals, supporting multiple models.
Parameters
----------
df pandas.DataFrame
df : pandas.DataFrame
The dataframe containing the coefficient estimates and confidence interval bounds to plot.
Must include a 'fml' column identifying different models.
figsize : tuple
The size of the figure.
alpha : float
Expand All @@ -457,58 +463,103 @@ def _coefplot_matplotlib(
Whether to flip the coordinates of the plot. Default is True.
labels : dict, optional
A dictionary to relabel the variables. The keys are the original variable names and the values the new names.
dodge : float, optional
The amount to dodge each model's points by. Default is 0.5.
fig_kwargs : dict
Additional keyword arguments to pass to the matplotlib figure.
Returns
-------
object
A seaborn Plot object.
See Also
--------
- https://seaborn.pydata.org/tutorial/objects_interface.html
matplotlib.figure.Figure
A matplotlib Figure object.
"""
if labels is not None:
interactionSymbol = " x "
df["Coefficient"] = df["Coefficient"].apply(
lambda x: _relabel_expvar(x, labels, interactionSymbol)
)

ub, lb = (f"{round(x * 100, 1)}%" for x in [1 - alpha / 2, alpha / 2])

yintercept = yintercept if yintercept is not None else 0
ub, lb = alpha / 2, 1 - alpha / 2
title = title if title is not None else "Coefficient Plot"

_, ax = plt.subplots(figsize=figsize, **fig_kwargs)
if ax is None:
f, ax = plt.subplots(figsize=figsize, **fig_kwargs)
else:
f = ax.get_figure()

ax.axvline(x=yintercept, color="black", linestyle="--")
# Check if we have multiple models
models = df["fml"].unique()
is_multi_model = len(models) > 1

if xintercept is not None:
ax.axhline(y=xintercept, color="black", linestyle="--")
colors = plt.cm.jet(np.linspace(0, 1, len(models)))
color_dict = dict(zip(models, colors))

plot = (
so.Plot(df, x="Estimate", y="Coefficient", color="fml")
.add(so.Dot(), so.Dodge(empty="drop"))
.add(
so.Range(),
so.Dodge(empty="drop"),
xmin=str(round(lb * 100, 1)) + "%",
xmax=str(round(ub * 100, 1)) + "%",
)
.label(
title=title,
x=rf"Estimate and {round((1-alpha)*100, 1)}% Confidence Interval",
y="Coefficient",
color="Model",
)
.on(ax)
)
ax.tick_params(axis="x", rotation=rotate_xticks)
# Calculate the positions for dodging
unique_coefficients = df["Coefficient"].unique()
if is_multi_model:
coef_positions = {coef: i for i, coef in enumerate(unique_coefficients)}
dodge_start = -(len(models) - 1) * dodge / 2

if flip_coord:
ax.invert_yaxis()
for i, (model, group) in enumerate(df.groupby("fml")):
color = color_dict[model]

if is_multi_model:
dodge_val = dodge_start + i * dodge
x_pos = [coef_positions[coef] + dodge_val for coef in group["Coefficient"]]
else:
x_pos = list(map(float, range(len(group))))

err = [group["Estimate"] - group[lb], group[ub] - group["Estimate"]]

if flip_coord:
ax.errorbar(
x=group["Estimate"],
y=x_pos,
xerr=err,
fmt="o",
capsize=5,
color=color,
label=model if is_multi_model else "Estimates",
)
else:
ax.errorbar(
y=group["Estimate"],
x=x_pos,
# yerr=group["Std. Error"] * critval,
yerr=err,
fmt="o",
capsize=5,
color=color,
label=model if is_multi_model else "Estimates",
)

if flip_coord:
ax.axvline(x=yintercept, color="black", linestyle="--")
if xintercept is not None:
ax.axhline(y=xintercept, color="black", linestyle="--")
ax.set_xlabel(rf"Estimate and {round((1-alpha)*100, 1)}% Confidence Interval")
ax.set_ylabel("Coefficient")
ax.set_yticks(range(len(unique_coefficients)))
ax.set_yticklabels(unique_coefficients)
ax.tick_params(axis="y", rotation=rotate_xticks)
else:
ax.axhline(y=yintercept, color="black", linestyle="--")
if xintercept is not None:
ax.axvline(x=xintercept, color="black", linestyle="--")
ax.set_ylabel(rf"Estimate and {round((1-alpha)*100, 1)}% Confidence Interval")
ax.set_xlabel("Coefficient")
ax.set_xticks(range(len(unique_coefficients)))
ax.set_xticklabels(unique_coefficients)
ax.tick_params(axis="x", rotation=rotate_xticks)

ax.set_title(title)
if is_multi_model:
ax.legend()
plt.tight_layout()
plt.close()
return plot
return f


def _get_model_df(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd

from pyfixest.did.visualize import panelview
from pyfixest.estimation.estimation import feols, fepois
Expand Down

0 comments on commit b66f517

Please sign in to comment.