Skip to content

Commit

Permalink
Add dynamic range and smoothness metrics, and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
alishbaimran committed Nov 19, 2024
1 parent 820c805 commit a73b9a0
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 3 deletions.
106 changes: 106 additions & 0 deletions applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# %%
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation.distance import (
calculate_normalized_euclidean_distance_cell,
compute_displacement_mean_std_full,
compute_dynamic_smoothness_metrics,
)

# %% Paths to datasets for different intervals
# Maps a human-readable label to the zarr store that is loaded with
# read_embedding_dataset. The labels are reused verbatim as keys of `metrics`
# and in the plot legends/titles and printed summaries further down.
# NOTE(review): the "N min interval" labels presumably name the temporal
# sampling interval each model was trained with — confirm with the authors.
feature_paths = {
    "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
    "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr",
    "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_28mins.zarr",
    "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr",
}

# Baseline embedding stores. Their metrics are stored in `metrics` under the
# keys "No Tracking" and "Cell Aware" respectively; the "No Tracking" curve is
# also overlaid on every per-interval plot.
no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr"
cell_aware_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr"

# Parameters
# Largest time shift (tau) passed to compute_displacement_mean_std_full.
max_tau = 69

# label -> {"dynamic_range", "smoothness", "mean_displacement", "std_displacement"}
metrics = {}

def _compute_displacement_metrics(dataset_path, max_tau):
    """
    Load an embedding dataset and compute its displacement-curve metrics.

    Parameters:
        dataset_path : str or Path
            Location of the embedding store passed to read_embedding_dataset.
        max_tau : int
            Maximum tau forwarded to compute_displacement_mean_std_full.

    Returns:
        dict
            Keys "dynamic_range", "smoothness", "mean_displacement" and
            "std_displacement"; the last two are the per-tau dicts returned
            by compute_displacement_mean_std_full.
    """
    embedding_dataset = read_embedding_dataset(Path(dataset_path))
    mean_displacement, std_displacement = compute_displacement_mean_std_full(
        embedding_dataset, max_tau
    )
    dynamic_range, smoothness = compute_dynamic_smoothness_metrics(
        mean_displacement
    )
    return {
        "dynamic_range": dynamic_range,
        "smoothness": smoothness,
        "mean_displacement": mean_displacement,
        "std_displacement": std_displacement,
    }


def _print_metrics(label, entry):
    """Print the dynamic range and smoothness stored in one `metrics` entry."""
    print(f"Metrics for {label}:")
    print(f" Dynamic Range: {entry['dynamic_range']}")
    print(f" Smoothness: {entry['smoothness']}")


def _plot_mean_with_std_band(entry, color, curve_label, band_label):
    """
    Draw a mean-displacement curve with a +/- one standard deviation band.

    The x-axis uses the taus of `entry` itself, so each curve is always
    plotted against its own keys even if two datasets yield different taus.
    """
    taus = list(entry["mean_displacement"].keys())
    means = np.array(list(entry["mean_displacement"].values()))
    stds = np.array(list(entry["std_displacement"].values()))

    plt.plot(taus, means, marker="o", label=curve_label, color=color)
    plt.fill_between(
        taus,
        means - stds,
        means + stds,
        color=color,
        alpha=0.3,
        label=band_label,
    )


# Baselines first, so that "No Tracking" is available as the reference curve
# for every per-interval figure below.
metrics["No Tracking"] = _compute_displacement_metrics(no_track_path, max_tau)
_print_metrics("No Tracking", metrics["No Tracking"])

metrics["Cell Aware"] = _compute_displacement_metrics(cell_aware_path, max_tau)
_print_metrics("Cell Aware", metrics["Cell Aware"])

# One figure per interval: the interval's curve (green) against the
# no-tracking baseline (blue).
for label, path in feature_paths.items():
    metrics[label] = _compute_displacement_metrics(path, max_tau)

    plt.figure(figsize=(10, 6))
    _plot_mean_with_std_band(
        metrics[label],
        color="green",
        curve_label=f"{label}",
        band_label=f"Std Dev ({label})",
    )
    _plot_mean_with_std_band(
        metrics["No Tracking"],
        color="blue",
        curve_label="Classical Contrastive (No Tracking)",
        band_label="Std Dev (No Tracking)",
    )

    plt.xlabel("Time Shift (τ)")
    plt.ylabel("Euclidean Distance")
    plt.title(f"Embedding Displacement Over Time ({label})")
    plt.grid(True)
    plt.legend()
    plt.show()

    _print_metrics(label, metrics[label])
# %%
129 changes: 126 additions & 3 deletions viscy/representation/evaluation/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,23 @@


def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
"""Extract embeddings and calculate cosine similarities for a specific cell"""
"""
Calculate cosine similarities for a specific cell over time.
Parameters:
embedding_dataset : xarray.Dataset
The dataset containing embeddings, timepoints, fov_name, and track_id.
fov_name : str
Field of view name to identify the specific cell.
track_id : int
Track ID to identify the specific cell.
Returns:
tuple
- time_points (array): Array of time points for the cell.
- cosine_similarities (list): Cosine similarities between the embedding
at the first time point and each subsequent time point.
"""
# Filter the dataset for the specific infected cell
filtered_data = embedding_dataset.where(
(embedding_dataset["fov_name"] == fov_name)
Expand Down Expand Up @@ -34,7 +50,25 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
def compute_displacement_mean_std(
embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False
):
"""Compute the norm of differences between embeddings at t and t + tau"""
"""
Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for each tau.
Parameters:
embedding_dataset : xarray.Dataset
The dataset containing embeddings, timepoints, fov_name, and track_id.
max_tau : int, optional
The maximum tau value to compute displacements for. Default is 10.
use_cosine : bool, optional
If True, compute cosine similarity instead of Euclidean distance. Default is False.
use_dissimilarity : bool, optional
If True and use_cosine is True, compute cosine dissimilarity (1 - similarity).
Default is False.
Returns:
tuple
- mean_displacement_per_tau (dict): Mean displacement for each tau.
- std_displacement_per_tau (dict): Standard deviation of displacements for each tau.
"""
# Get the arrays of (fov_name, track_id, t, and embeddings)
fov_names = embedding_dataset["fov_name"].values
track_ids = embedding_dataset["track_id"].values
Expand Down Expand Up @@ -104,7 +138,27 @@ def compute_displacement(
use_dissimilarity=False,
use_umap=False,
):
"""Compute the norm of differences between embeddings at t and t + tau"""
"""
Compute the displacements between embeddings at time t and t + tau for each tau.
Parameters:
embedding_dataset : xarray.Dataset
The dataset containing embeddings, timepoints, fov_name, and track_id.
max_tau : int, optional
The maximum tau value to compute displacements for. Default is 10.
use_cosine : bool, optional
If True, compute cosine similarity instead of Euclidean distance. Default is False.
use_dissimilarity : bool, optional
If True and use_cosine is True, compute cosine dissimilarity (1 - similarity).
Default is False.
use_umap : bool, optional
If True, use UMAP embeddings instead of feature embeddings. Default is False.
Returns:
dict
A dictionary where the key is tau and the value is a list of displacements
for all cells at that tau.
"""
# Get the arrays of (fov_name, track_id, t, and embeddings)
fov_names = embedding_dataset["fov_name"].values
track_ids = embedding_dataset["track_id"].values
Expand Down Expand Up @@ -164,6 +218,23 @@ def compute_displacement(


def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id):
"""
Calculate the normalized Euclidean distance between the embedding at the first time point and each subsequent time point for a specific cell.
Parameters:
embedding_dataset : xarray.Dataset
The dataset containing embeddings, timepoints, fov_name, and track_id.
fov_name : str
Field of view name to identify the specific cell.
track_id : int
Track ID to identify the specific cell.
Returns:
tuple
- time_points (array): Array of time points for the cell.
- euclidean_distances (list): Normalized Euclidean distances between the embedding
at the first time point and each subsequent time point.
"""
filtered_data = embedding_dataset.where(
(embedding_dataset["fov_name"] == fov_name)
& (embedding_dataset["track_id"] == track_id),
Expand All @@ -189,6 +260,20 @@ def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, tr


def compute_displacement_mean_std_full(embedding_dataset, max_tau=10):
"""
Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells.
Parameters:
embedding_dataset : xarray.Dataset
The dataset containing embeddings, timepoints, fov_name, and track_id.
max_tau : int, optional
The maximum tau value to compute displacements for. Default is 10.
Returns:
tuple
- mean_displacement_per_tau (dict): Mean displacement for each tau.
- std_displacement_per_tau (dict): Standard deviation of displacements for each tau.
"""
fov_names = embedding_dataset["fov_name"].values
track_ids = embedding_dataset["track_id"].values
timepoints = embedding_dataset["t"].values
Expand Down Expand Up @@ -247,3 +332,41 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10):
}

return mean_displacement_per_tau, std_displacement_per_tau


# Function to compute metrics for dynamic range and smoothness
def compute_dynamic_smoothness_metrics(mean_displacement_per_tau):
    """
    Compute dynamic range and smoothness metrics for a displacement curve.

    Parameters:
        mean_displacement_per_tau : dict
            Mapping from tau to the mean displacement at that tau. The keys
            are sorted before use, so insertion order does not matter.

    Returns:
        tuple
            - dynamic_range (float): Maximum minus minimum mean displacement.
            - smoothness (float): RMS of the second differences of the
              min-max normalized curve; lower values indicate smoother
              curves. NaN when the curve has fewer than three points,
              because no second difference exists.

    Raises:
        ValueError
            If mean_displacement_per_tau is empty.
    """
    if not mean_displacement_per_tau:
        raise ValueError(
            "mean_displacement_per_tau is empty; at least one tau is required"
        )

    taus = sorted(mean_displacement_per_tau)
    displacements = np.array([mean_displacement_per_tau[tau] for tau in taus])

    lowest = np.min(displacements)
    highest = np.max(displacements)
    dynamic_range = highest - lowest

    # Min-max normalize so smoothness is comparable across curves of
    # different scale. A constant curve has zero range, so it is only
    # shifted to zero to avoid dividing by zero.
    displacements_normalized = displacements - lowest
    if highest != lowest:
        displacements_normalized = displacements_normalized / dynamic_range

    second_diff = np.diff(displacements_normalized, n=2)

    if second_diff.size == 0:
        # Fewer than three points: curvature is undefined. Return NaN
        # explicitly instead of letting np.mean warn on an empty slice.
        smoothness = np.nan
    else:
        # RMS of second differences; lower values indicate smoother curves.
        smoothness = np.sqrt(np.mean(second_diff**2))

    return dynamic_range, smoothness

0 comments on commit a73b9a0

Please sign in to comment.