Skip to content

Commit

Permalink
Added test for datetime input on jointplot, wrote code to pass test (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
athompson1991 committed Aug 4, 2024
1 parent b4e5f8d commit 5a098e7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
12 changes: 12 additions & 0 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,7 @@ def __init__(
self.ax_joint = ax_joint
self.ax_marg_x = ax_marg_x
self.ax_marg_y = ax_marg_y
self.is_datetime_x = False

# Turn off tick visibility for the measure axis on the marginal plots
plt.setp(ax_marg_x.get_xticklabels(), visible=False)
Expand Down Expand Up @@ -1736,6 +1737,9 @@ def get_var(var):
vector = plot_data.get(var, None)
if vector is not None:
vector = vector.rename(p.variables.get(var, None))
if np.issubdtype(vector, np.datetime64):
vector = vector.astype(int)
self.is_datetime_x = True
return vector

self.x = get_var("x")
Expand Down Expand Up @@ -1833,6 +1837,14 @@ def plot_joint(self, func, **kwargs):
else:
func(self.x, self.y, **kwargs)

if self.is_datetime_x:
xtick_arr = self.ax_joint.get_xticks()
date_label = xtick_arr.astype("datetime64[ns]")
date_label = [str(dt)[0:10] for dt in date_label]
self.ax_joint.set_xticks(self.ax_joint.get_xticks())
self.ax_joint.set_xticklabels(date_label, rotation=45)
self._figure.tight_layout()

return self

def plot_marginals(self, func, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2376,6 +2376,8 @@ def _freedman_diaconis_bins(a):
"""Calculate number of hist bins using Freedman-Diaconis rule."""
# From https://stats.stackexchange.com/questions/798/
a = np.asarray(a)
if a.dtype.type == np.datetime64:
a = a.astype(np.int64)
if len(a) < 2:
return 1
iqr = np.subtract.reduce(np.nanpercentile(a, [75, 25]))
Expand Down
6 changes: 6 additions & 0 deletions tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,3 +1878,9 @@ def test_ax_warning(self, long_df):
with pytest.warns(UserWarning):
g = ag.jointplot(data=long_df, x="x", y="y", ax=ax)
assert g.ax_joint.collections

def test_datetime_input(self):
dates = np.array(
["2023-01-01", "2023-01-02", "2023-01-03"], dtype="datetime64[ns]"
)
ag.jointplot(x=dates, y=[1, 2, 3], kind="hex")

0 comments on commit 5a098e7

Please sign in to comment.