-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* nexnt loss prototype * fix bug with z_scale_range in hcs datamodule. If the value is an int this does not work. * exclude the negative pair from dataloader and forward pass * adding option using pytorch-metric-learning implementation and modifying previous to match same input args * removing our implementation of NTXentLoss and using pytorch metric * ruff * remove blank line diff * remove blank line diff * simplying the engine * PHATE (#210) * translation: fix validation loss aggregation (#202) * exposing prefetch and persistent worker (#203) * metrics for dynamic, smoothness and docstrings * updated metrics and plots for distance * fixed CI test cases * nexnt loss prototype * fix bug with z_scale_range in hcs datamodule. If the value is an int this does not work. * exclude the negative pair from dataloader and forward pass * adding option using pytorch-metric-learning implementation and modifying previous to match same input args * removing our implementation of NTXentLoss and using pytorch metric * ruff * prototype for phate and umap plot * - proofreading the calculations - removing unecessary calls to ALFI script - simplifying code to re-use functions * methods to rank nearest neighbors in embeddings * example script to plot state change of a single track * test using scaled features * phate embeddings * removing dataframe from the compute_phate adding docstring * adding phate to the prediction writer and moving it as dependency. * changing the phate defaults in the prediction writer. * ruff * fixing bug in phate in predict writer * adding code for measuring the smoothness * cleanup to run on triplet and ntxent * fix plots for smoothnes * nexnt loss prototype * exclude the negative pair from dataloader and forward pass * adding option using pytorch-metric-learning implementation and modifying previous to match same input args * removing our implementation of NTXentLoss and using pytorch metric * ruff * remove blank line diff * remove blank line diff * simplying the engine * explicit target shape argument in the HCS data module * Revert "explicit target shape argument in the HCS data module" This reverts commit 464d4c9. * Explicit target shape argument in the HCS data module (#212) * explicit target shape argument in the HCS data module * update docstring * update test cases * Gradio example (#158) * initial demo * using the predict_step * modifying paths to chkpt and example pngs * updating gradio as the one on Huggingface * adding configurable phate arguments via config * script to recompute phate and overwrite the previous phate data * ruff * solving redundancies * modularizing the smoothness * removing redundant _fit_phate() * ruff --------- Co-authored-by: Ziwen Liu <[email protected]> Co-authored-by: Alishba Imran <[email protected]> Co-authored-by: Ziwen Liu <[email protected]> * renaming cross_dissimilairy with pairwaise_distance_matrix --------- Co-authored-by: Ziwen Liu <[email protected]> Co-authored-by: Alishba Imran <[email protected]> Co-authored-by: Ziwen Liu <[email protected]>
- Loading branch information
1 parent
e88a174
commit 316deee
Showing
13 changed files
with
942 additions
and
217 deletions.
There are no files selected for viewing
312 changes: 312 additions & 0 deletions
312
applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
# %% | ||
from pathlib import Path | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from viscy.representation.embedding_writer import read_embedding_dataset | ||
from viscy.representation.evaluation.distance import ( | ||
calculate_normalized_euclidean_distance_cell, | ||
compute_displacement, | ||
compute_dynamic_range, | ||
compute_rms_per_track, | ||
) | ||
from collections import defaultdict | ||
from tabulate import tabulate | ||
|
||
import numpy as np | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
from collections import OrderedDict | ||
|
||
# %% function | ||
|
||
# Removed redundant compute_displacement_mean_std_full function | ||
# Removed redundant compute_dynamic_range and compute_rms_per_track functions | ||
|
||
|
||
def plot_rms_histogram(rms_values, label, bins=30): | ||
""" | ||
Plot histogram of RMS values across tracks. | ||
Parameters: | ||
rms_values : list | ||
List of RMS values, one for each track. | ||
label : str | ||
Label for the dataset (used in the title). | ||
bins : int, optional | ||
Number of bins for the histogram. Default is 30. | ||
Returns: | ||
None: Displays the histogram. | ||
""" | ||
plt.figure(figsize=(10, 6)) | ||
plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black") | ||
plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16) | ||
plt.xlabel("RMS of Time Derivative", fontsize=14) | ||
plt.ylabel("Frequency", fontsize=14) | ||
plt.grid(True) | ||
plt.show() | ||
|
||
|
||
def plot_displacement( | ||
mean_displacement, std_displacement, label, metrics_no_track=None | ||
): | ||
""" | ||
Plot embedding displacement over time with mean and standard deviation. | ||
Parameters: | ||
mean_displacement : dict | ||
Mean displacement for each tau. | ||
std_displacement : dict | ||
Standard deviation of displacement for each tau. | ||
label : str | ||
Label for the dataset. | ||
metrics_no_track : dict, optional | ||
Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against. | ||
Returns: | ||
None: Displays the plot. | ||
""" | ||
plt.figure(figsize=(10, 6)) | ||
taus = list(mean_displacement.keys()) | ||
mean_values = list(mean_displacement.values()) | ||
std_values = list(std_displacement.values()) | ||
|
||
plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green") | ||
plt.fill_between( | ||
taus, | ||
np.array(mean_values) - np.array(std_values), | ||
np.array(mean_values) + np.array(std_values), | ||
color="green", | ||
alpha=0.3, | ||
label=f"Std Dev ({label})", | ||
) | ||
|
||
if metrics_no_track: | ||
mean_values_no_track = list(metrics_no_track["mean_displacement"].values()) | ||
std_values_no_track = list(metrics_no_track["std_displacement"].values()) | ||
|
||
plt.plot( | ||
taus, | ||
mean_values_no_track, | ||
marker="o", | ||
label="Classical Contrastive (No Tracking)", | ||
color="blue", | ||
) | ||
plt.fill_between( | ||
taus, | ||
np.array(mean_values_no_track) - np.array(std_values_no_track), | ||
np.array(mean_values_no_track) + np.array(std_values_no_track), | ||
color="blue", | ||
alpha=0.3, | ||
label="Std Dev (No Tracking)", | ||
) | ||
|
||
plt.xlabel("Time Shift (τ)", fontsize=14) | ||
plt.ylabel("Euclidean Distance", fontsize=14) | ||
plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16) | ||
plt.grid(True) | ||
plt.legend(fontsize=12) | ||
plt.show() | ||
|
||
|
||
def plot_overlay_displacement(overlay_displacement_data): | ||
""" | ||
Plot embedding displacement over time for all datasets in one plot. | ||
Parameters: | ||
overlay_displacement_data : dict | ||
A dictionary containing mean displacement per tau for all datasets. | ||
Returns: | ||
None: Displays the plot. | ||
""" | ||
plt.figure(figsize=(12, 8)) | ||
for label, mean_displacement in overlay_displacement_data.items(): | ||
taus = list(mean_displacement.keys()) | ||
mean_values = list(mean_displacement.values()) | ||
plt.plot(taus, mean_values, marker="o", label=label) | ||
|
||
plt.xlabel("Time Shift (τ)", fontsize=14) | ||
plt.ylabel("Euclidean Distance", fontsize=14) | ||
plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) | ||
plt.grid(True) | ||
plt.legend(fontsize=12) | ||
plt.show() | ||
|
||
|
||
# %% hist stats | ||
def plot_boxplot_rms_across_models(datasets_rms): | ||
""" | ||
Plot a boxplot for the distribution of RMS values across models. | ||
Parameters: | ||
datasets_rms : dict | ||
A dictionary where keys are dataset names and values are lists of RMS values. | ||
Returns: | ||
None: Displays the boxplot. | ||
""" | ||
plt.figure(figsize=(12, 6)) | ||
labels = list(datasets_rms.keys()) | ||
data = list(datasets_rms.values()) | ||
print(labels) | ||
print(data) | ||
# Plot the boxplot | ||
plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) | ||
|
||
plt.title( | ||
"Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16 | ||
) | ||
plt.ylabel("RMS of Time Derivative", fontsize=14) | ||
plt.xticks(rotation=45, fontsize=12) | ||
plt.grid(axis="y", linestyle="--", alpha=0.7) | ||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
def plot_histogram_absolute_differences(datasets_abs_diff): | ||
""" | ||
Plot histograms of absolute differences across embeddings for all models. | ||
Parameters: | ||
datasets_abs_diff : dict | ||
A dictionary where keys are dataset names and values are lists of absolute differences. | ||
Returns: | ||
None: Displays the histograms. | ||
""" | ||
plt.figure(figsize=(12, 6)) | ||
for label, abs_diff in datasets_abs_diff.items(): | ||
plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) | ||
|
||
plt.title("Histograms of Absolute Differences Across Models", fontsize=16) | ||
plt.xlabel("Absolute Difference", fontsize=14) | ||
plt.ylabel("Density", fontsize=14) | ||
plt.legend(fontsize=12) | ||
plt.grid(alpha=0.7) | ||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
# %% Paths to datasets | ||
feature_paths = { | ||
"7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", | ||
"21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", | ||
"28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", | ||
"56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", | ||
"Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", | ||
} | ||
|
||
no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" | ||
|
||
# %% Process Datasets | ||
max_tau = 69 | ||
metrics = {} | ||
|
||
overlay_displacement_data = {} | ||
datasets_rms = {} | ||
datasets_abs_diff = {} | ||
|
||
# Process "No Tracking" dataset | ||
features_path_no_track = Path(no_track_path) | ||
embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) | ||
|
||
mean_displacement_no_track, std_displacement_no_track = compute_displacement( | ||
embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True | ||
) | ||
dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) | ||
metrics["No Tracking"] = { | ||
"dynamic_range": dynamic_range_no_track, | ||
"mean_displacement": mean_displacement_no_track, | ||
"std_displacement": std_displacement_no_track, | ||
} | ||
|
||
overlay_displacement_data["No Tracking"] = mean_displacement_no_track | ||
|
||
print("\nProcessing No Tracking dataset...") | ||
print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") | ||
|
||
plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking") | ||
|
||
rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) | ||
datasets_rms["No Tracking"] = rms_values_no_track | ||
|
||
print(f"Plotting histogram of RMS values for No Tracking dataset...") | ||
plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) | ||
|
||
# Compute absolute differences for "No Tracking" | ||
abs_diff_no_track = np.concatenate( | ||
[ | ||
np.linalg.norm( | ||
np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), | ||
axis=-1, | ||
) | ||
for indices in np.split( | ||
np.arange(len(embedding_dataset_no_track["track_id"])), | ||
np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1, | ||
) | ||
] | ||
) | ||
datasets_abs_diff["No Tracking"] = abs_diff_no_track | ||
|
||
# Process other datasets | ||
for label, path in feature_paths.items(): | ||
print(f"\nProcessing {label} dataset...") | ||
|
||
features_path = Path(path) | ||
embedding_dataset = read_embedding_dataset(features_path) | ||
|
||
mean_displacement, std_displacement = compute_displacement( | ||
embedding_dataset, max_tau=max_tau, return_mean_std=True | ||
) | ||
dynamic_range = compute_dynamic_range(mean_displacement) | ||
metrics[label] = { | ||
"dynamic_range": dynamic_range, | ||
"mean_displacement": mean_displacement, | ||
"std_displacement": std_displacement, | ||
} | ||
|
||
overlay_displacement_data[label] = mean_displacement | ||
|
||
print(f"Dynamic Range for {label}: {dynamic_range}") | ||
|
||
plot_displacement( | ||
mean_displacement, | ||
std_displacement, | ||
label, | ||
metrics_no_track=metrics.get("No Tracking", None), | ||
) | ||
|
||
rms_values = compute_rms_per_track(embedding_dataset) | ||
datasets_rms[label] = rms_values | ||
|
||
print(f"Plotting histogram of RMS values for {label}...") | ||
plot_rms_histogram(rms_values, label, bins=30) | ||
|
||
abs_diff = np.concatenate( | ||
[ | ||
np.linalg.norm( | ||
np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1 | ||
) | ||
for indices in np.split( | ||
np.arange(len(embedding_dataset["track_id"])), | ||
np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1, | ||
) | ||
] | ||
) | ||
datasets_abs_diff[label] = abs_diff | ||
|
||
print("\nPlotting overlayed displacement for all datasets...") | ||
plot_overlay_displacement(overlay_displacement_data) | ||
|
||
print("\nSummary of Dynamic Ranges:") | ||
for label, metric in metrics.items(): | ||
print(f"{label}: Dynamic Range = {metric['dynamic_range']}") | ||
|
||
print("\nPlotting RMS boxplot across models...") | ||
plot_boxplot_rms_across_models(datasets_rms) | ||
|
||
print("\nPlotting histograms of absolute differences across models...") | ||
plot_histogram_absolute_differences(datasets_abs_diff) | ||
|
||
|
||
# %% |
Oops, something went wrong.