Skip to content

Commit

Permalink
fix histogram and update example
Browse files Browse the repository at this point in the history
  • Loading branch information
AlishaAng committed Nov 5, 2024
1 parent 2c35d56 commit cbdf320
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 98 deletions.
6 changes: 3 additions & 3 deletions bursty_dynamics/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def find_interval_times(events):
if events.dtype.type != np.datetime64:
raise ValueError("Input must be datetime64 elements.")

# Ensure there are at least two events to find intervals
if len(events) < 2:
raise ValueError("At least two datetime events are required to find intervals.")
# # Ensure there are at least two events to find intervals. Unnecessary noise if there are a lot of subjects with less than 2 events.
# if len(events) < 2:
# raise ValueError("At least two datetime events are required to find intervals.")

# Calculate the intervals in minutes
iet = (np.diff(events).astype('timedelta64[ns]').astype(int) / (10**9 * 60)).tolist()
Expand Down
24 changes: 17 additions & 7 deletions bursty_dynamics/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def histogram(df, hist=True, set_axis=False, hue=None, **kwargs):
if hist in ['BP', 'MC']:
color = hist_options[hist]
fig = plt.figure()
sns.histplot(data=df, x=hist, kde=True, hue=hue, palette=palette, color = color, **kwargs)
plot = sns.histplot(data=df, x=hist, kde=True, hue=hue, palette=palette, color = color, **kwargs)
sns.move_legend(plot, "upper left", bbox_to_anchor=(1, 1))
if set_axis:
plt.xlim(-1, 1)
plt.close(fig)
Expand All @@ -132,19 +133,25 @@ def histogram(df, hist=True, set_axis=False, hue=None, **kwargs):
return fig

elif hist is True:
fig, axs = plt.subplots(1, len(hist_options), figsize=(12, 6)) # Create subplots based on number of hist_options
fig, axs = plt.subplots(2, 1, figsize=(7, 9)) # Create subplots based on number of hist_options
for i, (column, color) in enumerate({'BP': 'blue', 'MC': 'magenta'}.items()):
sns.histplot(data=df, x=column, hue=hue, kde=True, palette=palette, color=color, ax=axs[i], **kwargs)
axs[i].set_xlabel(column)

if set_axis:
axs[i].set_xlim(-1, 1)

if hue is not None:
sns.move_legend(axs[0], "upper right", bbox_to_anchor=(1.35, 1))
sns.move_legend(axs[1], "upper right", bbox_to_anchor=(1.35, 1))

# plt.tight_layout()
plt.close(fig)
return fig

else:
print("Invalid 'hist' parameter. Please choose from 'BP', 'MC', 'Both', or True.")



def scatterplot(df, hue=None, set_axis=False, **kwargs):
"""
Expand Down Expand Up @@ -192,8 +199,8 @@ def scatterplot(df, hue=None, set_axis=False, **kwargs):
palette="tab10",
joint_kws=dict(s=50, alpha=0.4, edgecolor=None),
**kwargs)
plot.ax_joint.legend(loc='upper right', bbox_to_anchor=(1.25, 1.22))
plt.subplots_adjust(right=0.8)
plt.legend(bbox_to_anchor=(1.25, 1), loc='upper left', title=hue)
plt.subplots_adjust(right=1)


else:
Expand Down Expand Up @@ -315,11 +322,14 @@ def event_counts(train_info_df, x_limit=30, hue=None, **kwargs):
sns.countplot(data=train_info_df[train_info_df["unique_event_counts"] <= x_limit],
x="unique_event_counts", hue=hue, palette=palette, ax=ax, **kwargs)

# Set axis labels
ax.set_xlabel('Number of Events per Train', fontsize=14)
ax.set_ylabel('Number of Trains', fontsize=14)
# Adjust tick parameters
ax.tick_params(axis='both', which='major', labelsize=14)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')

plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()

plt.close(fig) # Prevents the plot from displaying in interactive environments

Expand Down
182 changes: 95 additions & 87 deletions example/examples.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name= 'bursty_dynamics',
version = '0.1.4',
version = '0.1.5',
description = DESCRIPTION,
packages = find_packages(),
long_description = open('README.rst').read(),
Expand Down

0 comments on commit cbdf320

Please sign in to comment.