Skip to content

Commit

Permalink
- proofreading the calculations
Browse files Browse the repository at this point in the history
- removing unnecessary calls to ALFI script
- simplifying code to re-use functions
  • Loading branch information
edyoshikun committed Nov 20, 2024
1 parent 3f8363d commit 26ebe74
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 456 deletions.
233 changes: 50 additions & 183 deletions applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation.distance import (
calculate_normalized_euclidean_distance_cell,
compute_displacement,
compute_dynamic_range,
compute_rms_per_track,
)
from collections import defaultdict
from tabulate import tabulate
Expand All @@ -13,156 +16,10 @@
from sklearn.metrics.pairwise import cosine_similarity
from collections import OrderedDict

# %% function
# %% function

def compute_displacement_mean_std_full(embedding_dataset, max_tau=10):
    """
    Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells.

    A cell (track) is identified by its ``(fov_name, track_id)`` pair. Every
    embedding is scaled to unit L2 norm before distances are taken, so each
    displacement is the Euclidean distance between two unit vectors. Pairs
    are only formed inside a track and only when an observation exists at
    exactly ``t + tau``; gaps in a track simply contribute no pair.

    Parameters:
    embedding_dataset : xarray.Dataset
        The dataset containing embeddings, timepoints, fov_name, and track_id.
        Only ``embedding_dataset[key].values`` is read, for the keys
        "fov_name", "track_id", "t" and "features".
    max_tau : int, optional
        The maximum tau value to compute displacements for. Default is 10.
        tau = 0 is included (each observation paired with the first
        observation of its track at the same timepoint).

    Returns:
    tuple
        - mean_displacement_per_tau (dict): Mean displacement for each tau.
        - std_displacement_per_tau (dict): Standard deviation of displacements for each tau.
        Both dicts contain only the tau values for which at least one pair
        was found, and their keys are in ascending tau order.
    """
    fov_names = embedding_dataset["fov_name"].values
    track_ids = embedding_dataset["track_id"].values
    timepoints = np.asarray(embedding_dataset["t"].values)
    embeddings = np.asarray(embedding_dataset["features"].values)

    # Group row indices by cell in a single pass instead of rescanning the
    # whole dataset with np.where once per cell.
    rows_per_cell = defaultdict(list)
    for row, (fov_name, track_id) in enumerate(zip(fov_names, track_ids)):
        rows_per_cell[(fov_name, int(track_id))].append(row)

    displacement_per_tau = defaultdict(list)

    for rows in rows_per_cell.values():
        rows = np.asarray(rows)
        # Stable sort keeps the dataset order among duplicated timepoints.
        order = np.argsort(timepoints[rows], kind="stable")
        cell_timepoints = timepoints[rows][order]
        cell_embeddings = embeddings[rows][order]

        # Normalize once per track. An all-zero embedding is left as zeros
        # rather than divided by zero, which would turn the mean and std of
        # every tau it participates in into NaN.
        norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
        unit_embeddings = cell_embeddings / np.where(norms == 0, 1.0, norms)

        last = len(cell_timepoints) - 1
        for tau in range(max_tau + 1):
            targets = cell_timepoints + tau
            # Leftmost position of each target in the sorted timepoints,
            # i.e. the first observation at t + tau if there is one.
            first_match = np.searchsorted(cell_timepoints, targets, side="left")
            # Clip so targets past the end of the track can be indexed; the
            # equality test below rejects them.
            first_match = np.minimum(first_match, last)
            found = cell_timepoints[first_match] == targets
            if not found.any():
                continue

            deltas = unit_embeddings[found] - unit_embeddings[first_match[found]]
            displacement_per_tau[tau].extend(np.linalg.norm(deltas, axis=1).tolist())

    # Emit keys in ascending tau so callers that plot list(d.keys()) against
    # list(d.values()) get a monotonic x-axis even when the first track
    # processed had gaps.
    taus = sorted(displacement_per_tau)
    mean_displacement_per_tau = {
        tau: np.mean(displacement_per_tau[tau]) for tau in taus
    }
    std_displacement_per_tau = {
        tau: np.std(displacement_per_tau[tau]) for tau in taus
    }

    return mean_displacement_per_tau, std_displacement_per_tau


def compute_dynamic_range(mean_displacement_per_tau):
    """
    Compute the dynamic range as the difference between the maximum
    and minimum mean displacement per τ.

    Parameters:
    mean_displacement_per_tau: dict with τ as key and mean displacement as value

    Returns:
    float: dynamic range (max displacement - min displacement)

    Raises:
    ValueError: if mean_displacement_per_tau has no entries.
    """
    displacements = list(mean_displacement_per_tau.values())
    if not displacements:
        # max()/min() would raise ValueError here anyway; keep the exception
        # type but say which input was empty.
        raise ValueError(
            "mean_displacement_per_tau is empty; cannot compute a dynamic range"
        )
    return max(displacements) - min(displacements)


def compute_rms_per_track(embedding_dataset):
    """
    Compute RMS of the time derivative of embeddings per track.

    A track is identified by its ``(fov_name, track_id)`` pair. Within each
    track the embeddings are ordered by timepoint and differenced between
    consecutive observations; the differences are not divided by the time
    gap, so a missing timepoint is treated like a single step. Embeddings
    are used as stored (no normalization).

    Parameters:
    embedding_dataset : xarray.Dataset
        The dataset containing embeddings, timepoints, fov_name, and track_id.
        Only ``embedding_dataset[key].values`` is read, for the keys
        "fov_name", "track_id", "t" and "features".

    Returns:
    list: A list of RMS values, one for each track with at least two
        observations, ordered by (fov_name, track_id).
    """
    fov_names = embedding_dataset["fov_name"].values
    track_ids = embedding_dataset["track_id"].values
    timepoints = np.asarray(embedding_dataset["t"].values)
    embeddings = np.asarray(embedding_dataset["features"].values)

    # Group row indices by track in a single pass instead of rescanning the
    # whole dataset with np.where once per track.
    rows_per_cell = defaultdict(list)
    for row, (fov_name, track_id) in enumerate(zip(fov_names, track_ids)):
        rows_per_cell[(fov_name, int(track_id))].append(row)

    rms_values = []

    # Sorted keys keep the output in (fov_name, track_id) order.
    for cell in sorted(rows_per_cell):
        rows = np.asarray(rows_per_cell[cell])
        if len(rows) < 2:
            # A single observation has no consecutive difference.
            continue

        # Stable sort keeps the dataset order among duplicated timepoints.
        order = np.argsort(timepoints[rows], kind="stable")
        cell_embeddings = embeddings[rows][order]

        # One row per consecutive pair of observations.
        differences = np.diff(cell_embeddings, axis=0)
        step_norms = np.linalg.norm(differences, axis=1)
        rms_values.append(np.sqrt(np.mean(step_norms**2)))

    return rms_values
# Removed redundant compute_displacement_mean_std_full function
# Removed redundant compute_dynamic_range and compute_rms_per_track functions


def plot_rms_histogram(rms_values, label, bins=30):
Expand All @@ -188,7 +45,10 @@ def plot_rms_histogram(rms_values, label, bins=30):
plt.grid(True)
plt.show()

def plot_displacement(mean_displacement, std_displacement, label, metrics_no_track=None):

def plot_displacement(
mean_displacement, std_displacement, label, metrics_no_track=None
):
"""
Plot embedding displacement over time with mean and standard deviation.
Expand Down Expand Up @@ -247,6 +107,7 @@ def plot_displacement(mean_displacement, std_displacement, label, metrics_no_tra
plt.legend(fontsize=12)
plt.show()


def plot_overlay_displacement(overlay_displacement_data):
"""
Plot embedding displacement over time for all datasets in one plot.
Expand All @@ -263,15 +124,16 @@ def plot_overlay_displacement(overlay_displacement_data):
taus = list(mean_displacement.keys())
mean_values = list(mean_displacement.values())
plt.plot(taus, mean_values, marker="o", label=label)

plt.xlabel("Time Shift (τ)", fontsize=14)
plt.ylabel("Euclidean Distance", fontsize=14)
plt.title("Overlayed Embedding Displacement Over Time", fontsize=16)
plt.grid(True)
plt.legend(fontsize=12)
plt.show()

# %% hist stats

# %% hist stats
def plot_boxplot_rms_across_models(datasets_rms):
"""
Plot a boxplot for the distribution of RMS values across models.
Expand All @@ -290,11 +152,13 @@ def plot_boxplot_rms_across_models(datasets_rms):
print(data)
# Plot the boxplot
plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True)

plt.title("Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16)

plt.title(
"Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16
)
plt.ylabel("RMS of Time Derivative", fontsize=14)
plt.xticks(rotation=45, fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()

Expand All @@ -313,7 +177,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
plt.figure(figsize=(12, 6))
for label, abs_diff in datasets_abs_diff.items():
plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True)

plt.title("Histograms of Absolute Differences Across Models", fontsize=16)
plt.xlabel("Absolute Difference", fontsize=14)
plt.ylabel("Density", fontsize=14)
Expand All @@ -322,6 +186,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
plt.tight_layout()
plt.show()


# %% Paths to datasets
feature_paths = {
"7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
Expand All @@ -345,8 +210,8 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
features_path_no_track = Path(no_track_path)
embedding_dataset_no_track = read_embedding_dataset(features_path_no_track)

mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std_full(
embedding_dataset_no_track, max_tau=max_tau
mean_displacement_no_track, std_displacement_no_track = compute_displacement(
embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True
)
dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track)
metrics["No Tracking"] = {
Expand All @@ -360,11 +225,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
print("\nProcessing No Tracking dataset...")
print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}")

plot_displacement(
mean_displacement_no_track,
std_displacement_no_track,
"No Tracking"
)
plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking")

rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track)
datasets_rms["No Tracking"] = rms_values_no_track
Expand All @@ -373,14 +234,18 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30)

# Compute absolute differences for "No Tracking"
abs_diff_no_track = np.concatenate([
np.linalg.norm(
np.diff(embedding_dataset_no_track["features"].values[indices], axis=0),
axis=-1
)
for indices in np.split(np.arange(len(embedding_dataset_no_track["track_id"])),
np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1)
])
abs_diff_no_track = np.concatenate(
[
np.linalg.norm(
np.diff(embedding_dataset_no_track["features"].values[indices], axis=0),
axis=-1,
)
for indices in np.split(
np.arange(len(embedding_dataset_no_track["track_id"])),
np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1,
)
]
)
datasets_abs_diff["No Tracking"] = abs_diff_no_track

# Process other datasets
Expand All @@ -390,8 +255,8 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
features_path = Path(path)
embedding_dataset = read_embedding_dataset(features_path)

mean_displacement, std_displacement = compute_displacement_mean_std_full(
embedding_dataset, max_tau=max_tau
mean_displacement, std_displacement = compute_displacement(
embedding_dataset, max_tau=max_tau, return_mean_std=True
)
dynamic_range = compute_dynamic_range(mean_displacement)
metrics[label] = {
Expand All @@ -417,16 +282,17 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
print(f"Plotting histogram of RMS values for {label}...")
plot_rms_histogram(rms_values, label, bins=30)

abs_diff = np.concatenate([
np.linalg.norm(
np.diff(embedding_dataset["features"].values[indices], axis=0),
axis=-1
)
for indices in np.split(
np.arange(len(embedding_dataset["track_id"])),
np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1
)
])
abs_diff = np.concatenate(
[
np.linalg.norm(
np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1
)
for indices in np.split(
np.arange(len(embedding_dataset["track_id"])),
np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1,
)
]
)
datasets_abs_diff[label] = abs_diff

print("\nPlotting overlayed displacement for all datasets...")
Expand All @@ -443,3 +309,4 @@ def plot_histogram_absolute_differences(datasets_abs_diff):
plot_histogram_absolute_differences(datasets_abs_diff)


# %%
Loading

0 comments on commit 26ebe74

Please sign in to comment.