Skip to content

Commit

Permalink
Improve scatter plot to support diploid data + allow labels_file to c…
Browse files Browse the repository at this point in the history
…ontain labels not in dimredobj
  • Loading branch information
miriambt committed Jan 3, 2025
1 parent b7e697e commit d489b30
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions snputils/visualization/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def scatter(
Args:
dimredobj (np.ndarray):
Reduced dimensionality data; expected to have n_samples x 2 shape.
Reduced dimensionality data; expected to have `(n_haplotypes, 2)` shape.
labels_file (str):
Path to a TSV file with columns 'indID' and 'label', providing labels for coloring and annotating points.
abbreviation_inside_dots (bool):
Expand All @@ -41,29 +41,34 @@ def scatter(
Returns:
None
"""
# Load data from the dimension-reduced object and labels file
data = dimredobj.X_new_ # 2D data points for plotting
labels_df = pd.read_csv(labels_file, sep='\t') # Load labels from TSV
# Load labels from TSV
labels_df = pd.read_csv(labels_file, sep='\t')

# Initialize the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Filter labels based on the indIDs in dimredobj
sample_ids = dimredobj.samples_
filtered_labels_df = labels_df[labels_df['indID'].isin(sample_ids)]

# Define unique colors for each group label, either from color_palette or defaulting to 'tab10'
unique_labels = labels_df['label'].unique()
unique_labels = filtered_labels_df['label'].unique()
colors = color_palette if color_palette else cm.get_cmap('tab10', len(unique_labels))

# Initialize the plot
fig, ax = plt.subplots(figsize=(10, 8))

# Calculate the overall center of the plot (used for positioning arrows)
plot_center = data.mean(axis=0)
plot_center = dimredobj.X_new_.mean(axis=0)

# Dictionary to hold centroid positions for each label
centroids = {}

# Iterate through each unique label to plot points and centroids
# Plot data points and centroids by label
for i, label in enumerate(unique_labels):
# Filter points corresponding to the current label
indices = labels_df[labels_df['label'] == label].index
points = data[indices]
# Get sample IDs corresponding to the current label
sample_ids_for_label = filtered_labels_df[filtered_labels_df['label'] == label]['indID']

# Filter points based on sample IDs
points = dimredobj.X_new_[np.isin(dimredobj.samples_, sample_ids_for_label)]

if dots:
# Plot individual points for the current group
ax.scatter(points[:, 0], points[:, 1], s=30, color=colors(i), alpha=0.6, label=label)
Expand Down

0 comments on commit d489b30

Please sign in to comment.