From db808191096e94c9b1a5ade8c7ccdec50477d831 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Fri, 8 Nov 2024 09:47:56 -0800 Subject: [PATCH 01/44] translation: fix validation loss aggregation (#202) --- viscy/translation/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index aa7ac24b..698ff412 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -378,7 +378,6 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self): super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) - self.validation_step_outputs = [] # average within each dataloader loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] self.log( @@ -386,6 +385,8 @@ def on_validation_epoch_end(self): torch.tensor(loss_means).mean().to(self.device), sync_dist=True, ) + self.validation_step_outputs.clear() + self.validation_losses.clear() def on_test_start(self): """Load CellPose model for segmentation.""" From 820c805e7587812dbc8673c2b063701ac8936868 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 12 Nov 2024 16:48:55 -0800 Subject: [PATCH 02/44] exposing prefetch and persistent worker (#203) --- viscy/data/hcs.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 9ee37185..8831bdd7 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -318,6 +318,8 @@ def __init__( augmentations: list[MapTransform] = [], caching: bool = False, ground_truth_masks: Path | None = None, + persistent_workers=False, + prefetch_factor=None, ): super().__init__() self.data_path = Path(data_path) @@ -334,6 +336,8 @@ def __init__( self.caching = caching self.ground_truth_masks = ground_truth_masks self.prepare_data_per_node = True + self.persistent_workers = persistent_workers + self.prefetch_factor = prefetch_factor @property def cache_path(self): @@ -521,8 +525,8 @@ def train_dataloader(self): batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, shuffle=True, - persistent_workers=bool(self.num_workers), - prefetch_factor=4 if self.num_workers else None, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + persistent_workers=self.persistent_workers, collate_fn=_collate_samples, drop_last=True, ) @@ -533,8 +537,8 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - prefetch_factor=4 if self.num_workers else None, - persistent_workers=bool(self.num_workers), + prefetch_factor=self.prefetch_factor if self.num_workers else None, + persistent_workers=self.persistent_workers, ) def test_dataloader(self): From a73b9a05cd2b85844e38eb7387e6b3ab4855bb18 Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Mon, 18 Nov 2024 17:55:17 -0800 Subject: [PATCH 03/44] metrics for dynamic, smoothness and docstrings --- .../evaluation/ALFI_displacement.py | 106 ++++++++++++++ viscy/representation/evaluation/distance.py | 129 +++++++++++++++++- 2 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 applications/contrastive_phenotyping/evaluation/ALFI_displacement.py diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py new file mode 100644 index 00000000..363f1603 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py @@ -0,0 +1,106 @@ +# %% +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.distance import ( + calculate_normalized_euclidean_distance_cell, + compute_displacement_mean_std_full, + compute_dynamic_smoothness_metrics, +) + +# %% Paths to datasets for different intervals +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", + "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", +} + +no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" +cell_aware_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr" + +# Parameters +max_tau = 69 + +metrics = {} + +features_path_no_track = Path(no_track_path) +embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) + +mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) +dynamic_range_no_track, smoothness_no_track = compute_dynamic_smoothness_metrics(mean_displacement_no_track) + +metrics["No Tracking"] = { + "dynamic_range": dynamic_range_no_track, + "smoothness": smoothness_no_track, + "mean_displacement": mean_displacement_no_track, + "std_displacement": std_displacement_no_track, +} + +print("Metrics for No Tracking:") +print(f" Dynamic Range: {dynamic_range_no_track}") +print(f" Smoothness: {smoothness_no_track}") + +features_path_cell_aware = Path(cell_aware_path) +embedding_dataset_cell_aware = read_embedding_dataset(features_path_cell_aware) + +mean_displacement_cell_aware, std_displacement_cell_aware = compute_displacement_mean_std_full(embedding_dataset_cell_aware, max_tau) +dynamic_range_cell_aware, smoothness_cell_aware = compute_dynamic_smoothness_metrics(mean_displacement_cell_aware) + +metrics["Cell Aware"] = { + "dynamic_range": dynamic_range_cell_aware, + "smoothness": smoothness_cell_aware, + "mean_displacement": mean_displacement_cell_aware, + "std_displacement": std_displacement_cell_aware, +} + +print("Metrics for Cell Aware:") +print(f" Dynamic Range: {dynamic_range_cell_aware}") +print(f" Smoothness: {smoothness_cell_aware}") + +for label, path in feature_paths.items(): + features_path = Path(path) + embedding_dataset = read_embedding_dataset(features_path) + + mean_displacement, std_displacement = compute_displacement_mean_std_full(embedding_dataset, max_tau) + dynamic_range, smoothness = compute_dynamic_smoothness_metrics(mean_displacement) + + metrics[label] = { + "dynamic_range": dynamic_range, + "smoothness": smoothness, + "mean_displacement": mean_displacement, + "std_displacement": std_displacement, + } + + plt.figure(figsize=(10, 6)) + taus = list(mean_displacement.keys()) + mean_values = list(mean_displacement.values()) + std_values = list(std_displacement.values()) + + plt.plot(taus, mean_values, marker='o', label=f'{label}', color='green') + plt.fill_between(taus, + np.array(mean_values) - np.array(std_values), + np.array(mean_values) + np.array(std_values), + color='green', alpha=0.3, label=f'Std Dev ({label})') + + mean_values_no_track = list(metrics["No Tracking"]["mean_displacement"].values()) + std_values_no_track = list(metrics["No Tracking"]["std_displacement"].values()) + + plt.plot(taus, mean_values_no_track, marker='o', label='Classical Contrastive (No Tracking)', color='blue') + plt.fill_between(taus, + np.array(mean_values_no_track) - np.array(std_values_no_track), + np.array(mean_values_no_track) + np.array(std_values_no_track), + color='blue', alpha=0.3, label='Std Dev (No Tracking)') + + plt.xlabel('Time Shift (τ)') + plt.ylabel('Euclidean Distance') + plt.title(f'Embedding Displacement Over Time ({label})') + plt.grid(True) + plt.legend() + plt.show() + + print(f"Metrics for {label}:") + print(f" Dynamic Range: {dynamic_range}") + print(f" Smoothness: {smoothness}") +# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 5c8f7759..2fca9b79 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -5,7 +5,23 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): - """Extract embeddings and calculate cosine similarities for a specific cell""" + """ + Calculate cosine similarities for a specific cell over time. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + fov_name : str + Field of view name to identify the specific cell. + track_id : int + Track ID to identify the specific cell. + + Returns: + tuple + - time_points (array): Array of time points for the cell. + - cosine_similarities (list): Cosine similarities between the embedding + at the first time point and each subsequent time point. + """ # Filter the dataset for the specific infected cell filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) @@ -34,7 +50,25 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): def compute_displacement_mean_std( embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False ): - """Compute the norm of differences between embeddings at t and t + tau""" + """ + Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for each tau. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + max_tau : int, optional + The maximum tau value to compute displacements for. Default is 10. + use_cosine : bool, optional + If True, compute cosine similarity instead of Euclidean distance. Default is False. + use_dissimilarity : bool, optional + If True and use_cosine is True, compute cosine dissimilarity (1 - similarity). + Default is False. + + Returns: + tuple + - mean_displacement_per_tau (dict): Mean displacement for each tau. + - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. + """ # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values @@ -104,7 +138,27 @@ def compute_displacement( use_dissimilarity=False, use_umap=False, ): - """Compute the norm of differences between embeddings at t and t + tau""" + """ + Compute the displacements between embeddings at time t and t + tau for each tau. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + max_tau : int, optional + The maximum tau value to compute displacements for. Default is 10. + use_cosine : bool, optional + If True, compute cosine similarity instead of Euclidean distance. Default is False. + use_dissimilarity : bool, optional + If True and use_cosine is True, compute cosine dissimilarity (1 - similarity). + Default is False. + use_umap : bool, optional + If True, use UMAP embeddings instead of feature embeddings. Default is False. + + Returns: + dict + A dictionary where the key is tau and the value is a list of displacements + for all cells at that tau. + """ # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values @@ -164,6 +218,23 @@ def compute_displacement( def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): + """ + Calculate the normalized Euclidean distance between the embedding at the first time point and each subsequent time point for a specific cell. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + fov_name : str + Field of view name to identify the specific cell. + track_id : int + Track ID to identify the specific cell. + + Returns: + tuple + - time_points (array): Array of time points for the cell. + - euclidean_distances (list): Normalized Euclidean distances between the embedding + at the first time point and each subsequent time point. + """ filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), @@ -189,6 +260,20 @@ def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, tr def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): + """ + Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + max_tau : int, optional + The maximum tau value to compute displacements for. Default is 10. + + Returns: + tuple + - mean_displacement_per_tau (dict): Mean displacement for each tau. + - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. + """ fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values @@ -247,3 +332,41 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): } return mean_displacement_per_tau, std_displacement_per_tau + + +# Function to compute metrics for dynamic range and smoothness +def compute_dynamic_smoothness_metrics(mean_displacement_per_tau): + """ + Compute dynamic range and smoothness metrics for displacement curves. + + Parameters: + mean_displacement_per_tau: dict with tau as key and mean displacement as value + + Returns: + tuple: (dynamic_range, smoothness) + - dynamic_range: max displacement - min displacement + - smoothness: RMS of second differences of normalized curve + """ + taus = np.array(sorted(mean_displacement_per_tau.keys())) + displacements = np.array([mean_displacement_per_tau[tau] for tau in taus]) + + dynamic_range = np.max(displacements) - np.min(displacements) + + if np.max(displacements) != np.min(displacements): + displacements_normalized = (displacements - np.min(displacements)) / ( + np.max(displacements) - np.min(displacements) + ) + else: + displacements_normalized = displacements - np.min( + displacements + ) # Handle constant case + + first_diff = np.diff(displacements_normalized) + + second_diff = np.diff(first_diff) + + # Compute RMS of second differences as smoothness metric + # Lower values indicate smoother curves + smoothness = np.sqrt(np.mean(second_diff**2)) + + return dynamic_range, smoothness From 79db1e458cc19841bdd353f152f1c60bc96ba9da Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 19 Nov 2024 18:56:16 -0800 Subject: [PATCH 04/44] updated metrics and plots for distance --- .../evaluation/ALFI_displacement.py | 457 +++++++++++++++--- viscy/representation/evaluation/distance.py | 4 +- 2 files changed, 401 insertions(+), 60 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py index 363f1603..ba36e84d 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py @@ -5,102 +5,441 @@ from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.distance import ( calculate_normalized_euclidean_distance_cell, - compute_displacement_mean_std_full, - compute_dynamic_smoothness_metrics, ) +from collections import defaultdict +from tabulate import tabulate -# %% Paths to datasets for different intervals +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from collections import OrderedDict + +# %% function + +def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): + """ + Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + max_tau : int, optional + The maximum tau value to compute displacements for. Default is 10. + + Returns: + tuple + - mean_displacement_per_tau (dict): Mean displacement for each tau. + - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. + """ + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + cell_identifiers = np.array( + list(zip(fov_names, track_ids)), + dtype=[("fov_name", "O"), ("track_id", "int64")], + ) + + unique_cells = np.unique(cell_identifiers) + + displacement_per_tau = defaultdict(list) + + for cell in unique_cells: + fov_name = cell["fov_name"] + track_id = cell["track_id"] + + indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] + + cell_timepoints = timepoints[indices] + cell_embeddings = embeddings[indices] + + sorted_indices = np.argsort(cell_timepoints) + cell_timepoints = cell_timepoints[sorted_indices] + + cell_embeddings = cell_embeddings[sorted_indices] + + for i in range(len(cell_timepoints)): + current_time = cell_timepoints[i] + + current_embedding = cell_embeddings[i] + + current_embedding = current_embedding / np.linalg.norm(current_embedding) + + for tau in range(0, max_tau + 1): + future_time = current_time + tau + + + future_index = np.where(cell_timepoints == future_time)[0] + + if len(future_index) >= 1: + + future_embedding = cell_embeddings[future_index[0]] + future_embedding = future_embedding / np.linalg.norm( + future_embedding + ) + + distance = np.linalg.norm(current_embedding - future_embedding) + + displacement_per_tau[tau].append(distance) + + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + + return mean_displacement_per_tau, std_displacement_per_tau + + +def compute_dynamic_range(mean_displacement_per_tau): + """ + Compute the dynamic range as the difference between the maximum + and minimum mean displacement per τ. + + Parameters: + mean_displacement_per_tau: dict with τ as key and mean displacement as value + + Returns: + float: dynamic range (max displacement - min displacement) + """ + displacements = mean_displacement_per_tau.values() + return max(displacements) - min(displacements) + + +def compute_rms_per_track(embedding_dataset): + """ + Compute RMS of the time derivative of embeddings per track. + + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + + Returns: + list: A list of RMS values, one for each track. + """ + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + cell_identifiers = np.array( + list(zip(fov_names, track_ids)), + dtype=[("fov_name", "O"), ("track_id", "int64")], + ) + + unique_cells = np.unique(cell_identifiers) + + rms_values = [] + + for cell in unique_cells: + fov_name = cell["fov_name"] + track_id = cell["track_id"] + + indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] + + cell_timepoints = timepoints[indices] + cell_embeddings = embeddings[indices] + #print(cell_embeddings.shape) + + if len(cell_embeddings) < 2: + continue + + sorted_indices = np.argsort(cell_timepoints) + cell_embeddings = cell_embeddings[sorted_indices] + + # Compute differences between consecutive embeddings + differences = np.diff(cell_embeddings, axis=0) # Shape: (T-1, 768) + + if differences.shape[0] == 0: + continue + + # Compute RMS for this track + norms = np.linalg.norm(differences, axis=1) + if len(norms) > 0: + rms = np.sqrt(np.mean(norms**2)) + rms_values.append(rms) + + return rms_values + + +def plot_rms_histogram(rms_values, label, bins=30): + """ + Plot histogram of RMS values across tracks. + + Parameters: + rms_values : list + List of RMS values, one for each track. + label : str + Label for the dataset (used in the title). + bins : int, optional + Number of bins for the histogram. Default is 30. + + Returns: + None: Displays the histogram. + """ + plt.figure(figsize=(10, 6)) + plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black") + plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16) + plt.xlabel("RMS of Time Derivative", fontsize=14) + plt.ylabel("Frequency", fontsize=14) + plt.grid(True) + plt.show() + +def plot_displacement(mean_displacement, std_displacement, label, metrics_no_track=None): + """ + Plot embedding displacement over time with mean and standard deviation. + + Parameters: + mean_displacement : dict + Mean displacement for each tau. + std_displacement : dict + Standard deviation of displacement for each tau. + label : str + Label for the dataset. + metrics_no_track : dict, optional + Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against. + + Returns: + None: Displays the plot. + """ + plt.figure(figsize=(10, 6)) + taus = list(mean_displacement.keys()) + mean_values = list(mean_displacement.values()) + std_values = list(std_displacement.values()) + + plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green") + plt.fill_between( + taus, + np.array(mean_values) - np.array(std_values), + np.array(mean_values) + np.array(std_values), + color="green", + alpha=0.3, + label=f"Std Dev ({label})", + ) + + if metrics_no_track: + mean_values_no_track = list(metrics_no_track["mean_displacement"].values()) + std_values_no_track = list(metrics_no_track["std_displacement"].values()) + + plt.plot( + taus, + mean_values_no_track, + marker="o", + label="Classical Contrastive (No Tracking)", + color="blue", + ) + plt.fill_between( + taus, + np.array(mean_values_no_track) - np.array(std_values_no_track), + np.array(mean_values_no_track) + np.array(std_values_no_track), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", + ) + + plt.xlabel("Time Shift (τ)", fontsize=14) + plt.ylabel("Euclidean Distance", fontsize=14) + plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16) + plt.grid(True) + plt.legend(fontsize=12) + plt.show() + +def plot_overlay_displacement(overlay_displacement_data): + """ + Plot embedding displacement over time for all datasets in one plot. + + Parameters: + overlay_displacement_data : dict + A dictionary containing mean displacement per tau for all datasets. + + Returns: + None: Displays the plot. + """ + plt.figure(figsize=(12, 8)) + for label, mean_displacement in overlay_displacement_data.items(): + taus = list(mean_displacement.keys()) + mean_values = list(mean_displacement.values()) + plt.plot(taus, mean_values, marker="o", label=label) + + plt.xlabel("Time Shift (τ)", fontsize=14) + plt.ylabel("Euclidean Distance", fontsize=14) + plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) + plt.grid(True) + plt.legend(fontsize=12) + plt.show() + +# %% hist stats +def plot_boxplot_rms_across_models(datasets_rms): + """ + Plot a boxplot for the distribution of RMS values across models. + + Parameters: + datasets_rms : dict + A dictionary where keys are dataset names and values are lists of RMS values. + + Returns: + None: Displays the boxplot. + """ + plt.figure(figsize=(12, 6)) + labels = list(datasets_rms.keys()) + data = list(datasets_rms.values()) + print(labels) + print(data) + # Plot the boxplot + plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) + + plt.title("Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16) + plt.ylabel("RMS of Time Derivative", fontsize=14) + plt.xticks(rotation=45, fontsize=12) + plt.grid(axis='y', linestyle='--', alpha=0.7) + plt.tight_layout() + plt.show() + + +def plot_histogram_absolute_differences(datasets_abs_diff): + """ + Plot histograms of absolute differences across embeddings for all models. + + Parameters: + datasets_abs_diff : dict + A dictionary where keys are dataset names and values are lists of absolute differences. + + Returns: + None: Displays the histograms. + """ + plt.figure(figsize=(12, 6)) + for label, abs_diff in datasets_abs_diff.items(): + plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) + + plt.title("Histograms of Absolute Differences Across Models", fontsize=16) + plt.xlabel("Absolute Difference", fontsize=14) + plt.ylabel("Density", fontsize=14) + plt.legend(fontsize=12) + plt.grid(alpha=0.7) + plt.tight_layout() + plt.show() + +# %% Paths to datasets feature_paths = { "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", - "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_28mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", + "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", } no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" -cell_aware_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr" -# Parameters +# %% Process Datasets max_tau = 69 - metrics = {} +overlay_displacement_data = {} +datasets_rms = {} +datasets_abs_diff = {} + +# Process "No Tracking" dataset features_path_no_track = Path(no_track_path) embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) -mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std_full(embedding_dataset_no_track, max_tau) -dynamic_range_no_track, smoothness_no_track = compute_dynamic_smoothness_metrics(mean_displacement_no_track) - +mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std_full( + embedding_dataset_no_track, max_tau=max_tau +) +dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) metrics["No Tracking"] = { "dynamic_range": dynamic_range_no_track, - "smoothness": smoothness_no_track, "mean_displacement": mean_displacement_no_track, "std_displacement": std_displacement_no_track, } -print("Metrics for No Tracking:") -print(f" Dynamic Range: {dynamic_range_no_track}") -print(f" Smoothness: {smoothness_no_track}") +overlay_displacement_data["No Tracking"] = mean_displacement_no_track -features_path_cell_aware = Path(cell_aware_path) -embedding_dataset_cell_aware = read_embedding_dataset(features_path_cell_aware) +print("\nProcessing No Tracking dataset...") +print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") -mean_displacement_cell_aware, std_displacement_cell_aware = compute_displacement_mean_std_full(embedding_dataset_cell_aware, max_tau) -dynamic_range_cell_aware, smoothness_cell_aware = compute_dynamic_smoothness_metrics(mean_displacement_cell_aware) +plot_displacement( + mean_displacement_no_track, + std_displacement_no_track, + "No Tracking" +) -metrics["Cell Aware"] = { - "dynamic_range": dynamic_range_cell_aware, - "smoothness": smoothness_cell_aware, - "mean_displacement": mean_displacement_cell_aware, - "std_displacement": std_displacement_cell_aware, -} +rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) +datasets_rms["No Tracking"] = rms_values_no_track -print("Metrics for Cell Aware:") -print(f" Dynamic Range: {dynamic_range_cell_aware}") -print(f" Smoothness: {smoothness_cell_aware}") +print(f"Plotting histogram of RMS values for No Tracking dataset...") +plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) +# Compute absolute differences for "No Tracking" +abs_diff_no_track = np.concatenate([ + np.linalg.norm( + np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), + axis=-1 + ) + for indices in np.split(np.arange(len(embedding_dataset_no_track["track_id"])), + np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1) +]) +datasets_abs_diff["No Tracking"] = abs_diff_no_track + +# Process other datasets for label, path in feature_paths.items(): + print(f"\nProcessing {label} dataset...") + features_path = Path(path) embedding_dataset = read_embedding_dataset(features_path) - - mean_displacement, std_displacement = compute_displacement_mean_std_full(embedding_dataset, max_tau) - dynamic_range, smoothness = compute_dynamic_smoothness_metrics(mean_displacement) - + + mean_displacement, std_displacement = compute_displacement_mean_std_full( + embedding_dataset, max_tau=max_tau + ) + dynamic_range = compute_dynamic_range(mean_displacement) metrics[label] = { "dynamic_range": dynamic_range, - "smoothness": smoothness, "mean_displacement": mean_displacement, "std_displacement": std_displacement, } - - plt.figure(figsize=(10, 6)) - taus = list(mean_displacement.keys()) - mean_values = list(mean_displacement.values()) - std_values = list(std_displacement.values()) - plt.plot(taus, mean_values, marker='o', label=f'{label}', color='green') - plt.fill_between(taus, - np.array(mean_values) - np.array(std_values), - np.array(mean_values) + np.array(std_values), - color='green', alpha=0.3, label=f'Std Dev ({label})') + overlay_displacement_data[label] = mean_displacement - mean_values_no_track = list(metrics["No Tracking"]["mean_displacement"].values()) - std_values_no_track = list(metrics["No Tracking"]["std_displacement"].values()) + print(f"Dynamic Range for {label}: {dynamic_range}") - plt.plot(taus, mean_values_no_track, marker='o', label='Classical Contrastive (No Tracking)', color='blue') - plt.fill_between(taus, - np.array(mean_values_no_track) - np.array(std_values_no_track), - np.array(mean_values_no_track) + np.array(std_values_no_track), - color='blue', alpha=0.3, label='Std Dev (No Tracking)') + plot_displacement( + mean_displacement, + std_displacement, + label, + metrics_no_track=metrics.get("No Tracking", None), + ) + + rms_values = compute_rms_per_track(embedding_dataset) + datasets_rms[label] = rms_values + + print(f"Plotting histogram of RMS values for {label}...") + plot_rms_histogram(rms_values, label, bins=30) + + abs_diff = np.concatenate([ + np.linalg.norm( + np.diff(embedding_dataset["features"].values[indices], axis=0), + axis=-1 + ) + for indices in np.split( + np.arange(len(embedding_dataset["track_id"])), + np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1 + ) + ]) + datasets_abs_diff[label] = abs_diff + +print("\nPlotting overlayed displacement for all datasets...") +plot_overlay_displacement(overlay_displacement_data) + +print("\nSummary of Dynamic Ranges:") +for label, metric in metrics.items(): + print(f"{label}: Dynamic Range = {metric['dynamic_range']}") + +print("\nPlotting RMS boxplot across models...") +plot_boxplot_rms_across_models(datasets_rms) + +print("\nPlotting histograms of absolute differences across models...") +plot_histogram_absolute_differences(datasets_abs_diff) - plt.xlabel('Time Shift (τ)') - plt.ylabel('Euclidean Distance') - plt.title(f'Embedding Displacement Over Time ({label})') - plt.grid(True) - plt.legend() - plt.show() - print(f"Metrics for {label}:") - print(f" Dynamic Range: {dynamic_range}") - print(f" Smoothness: {smoothness}") -# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 2fca9b79..7d0a7913 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -306,18 +306,20 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): current_embedding = cell_embeddings[i] current_embedding = current_embedding / np.linalg.norm(current_embedding) - for tau in range(0, max_tau + 1): future_time = current_time + tau + future_index = np.where(cell_timepoints == future_time)[0] if len(future_index) >= 1: + future_embedding = cell_embeddings[future_index[0]] future_embedding = future_embedding / np.linalg.norm( future_embedding ) + distance = np.linalg.norm(current_embedding - future_embedding) displacement_per_tau[tau].append(distance) From 42cf7560915021d0cb7f5c1e7bc8b539f9c3155e Mon Sep 17 00:00:00 2001 From: Alishba Imran Date: Tue, 19 Nov 2024 18:58:59 -0800 Subject: [PATCH 05/44] fixed CI test cases --- viscy/representation/evaluation/distance.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 7d0a7913..5968fb03 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -309,7 +309,6 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): for tau in range(0, max_tau + 1): future_time = current_time + tau - future_index = np.where(cell_timepoints == future_time)[0] if len(future_index) >= 1: @@ -319,7 +318,6 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): future_embedding ) - distance = np.linalg.norm(current_embedding - future_embedding) displacement_per_tau[tau].append(distance) From 0a23fc4606cbe0c94198b154155cb887d6a916ad Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 11:40:11 -0800 Subject: [PATCH 06/44] nexnt loss prototype --- viscy/representation/engine.py | 79 ++++++++++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 8 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 42341c3c..2aec9853 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, TypedDict +from typing import Literal, Sequence, TypedDict, Tuple import numpy as np import torch @@ -14,6 +14,61 @@ _logger = logging.getLogger("lightning.pytorch") +class NTXentLoss(torch.nn.Module): + """ + Normalized Temperature-scaled Cross Entropy Loss + + From Chen et.al, https://arxiv.org/abs/2002.05709 + """ + + def __init__(self, batch_size, temperature=0.5): + super(NTXentLoss, self).__init__() + self.batch_size = batch_size + self.temperature = temperature + self.mask = self._get_correlated_mask(batch_size) + self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + + def _get_correlated_mask(self, batch_size): + mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) + mask = mask.fill_diagonal_(0) + for i in range(batch_size): + mask[i, batch_size + i] = 0 + mask[batch_size + i, i] = 0 + return mask + + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(self, zis, zjs): + """ + zis and zjs are the output projections from the two augmented views + + Here, we assume the two augmented views are the anchor and positive samples + """ + # Concatenate representations along the batch dimension + representations = torch.cat([zis, zjs], dim=0) + + # Cosine similarity + similarity_matrix = F.cosine_similarity( + representations.unsqueeze(1), representations.unsqueeze(0), dim=2 + ) + + # Temperature scaling + similarity_matrix = similarity_matrix / self.temperature + + # Find the valid pairs of positive samples + positive_samples = torch.cat( + [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 + ) + + # Mask out unwanted pairs + similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) + + # Calculate NT-Xent Loss as cross-entropy + loss = self.criterion(similarity_matrix, positive_samples) + loss /= 2 * self.batch_size + + return loss + + class ContrastivePrediction(TypedDict): features: Tensor projections: Tensor @@ -27,7 +82,7 @@ def __init__( self, encoder: nn.Module | ContrastiveEncoder, loss_function: ( - nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss + nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss ) = nn.TripletMarginLoss(margin=0.5), lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", @@ -106,9 +161,13 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_projection = self(anchor_img) negative_projection = self(neg_img) positive_projection = self(pos_img) - loss = self.loss_function( - anchor_projection, positive_projection, negative_projection - ) + if isinstance(self.loss_function, NTXentLoss): + # Note: we assume the two augmented views are the anchor and positive samples + loss = self.loss_function(anchor_projection, positive_projection) + else: + loss = self.loss_function( + anchor_projection, positive_projection, negative_projection + ) self._log_metrics( loss, anchor_projection, @@ -137,9 +196,13 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_projection = self(anchor) negative_projection = self(neg_img) positive_projection = self(pos_img) - loss = self.loss_function( - anchor_projection, positive_projection, negative_projection - ) + if isinstance(self.loss_function, NTXentLoss): + # Note: we assume the two augmented views are the anchor and positive samples + loss = self.loss_function(anchor_projection, positive_projection) + else: + loss = self.loss_function( + anchor_projection, positive_projection, negative_projection + ) self._log_metrics( loss, anchor_projection, positive_projection, negative_projection, "val" ) From f2d75cd2aa335ea912d30d1e62e49f64539ed6fe Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 13:48:30 -0800 Subject: [PATCH 07/44] fix bug with z_scale_range in hcs datamodule. If the value is an int this does not work. --- viscy/data/hcs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 8831bdd7..c4087941 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -606,7 +606,8 @@ def _train_transform(self) -> list[Callable]: else: self.augmentations = [] if z_scale_range is not None: - if isinstance(z_scale_range, float): + if isinstance(z_scale_range, (float, int)): + z_scale_range = float(z_scale_range) z_scale_range = (-z_scale_range, z_scale_range) if z_scale_range[0] > 0 or z_scale_range[1] < 0: raise ValueError(f"Invalid scaling range: {z_scale_range}") From aa74efff5f62f7cb13113225bb6f3c519cb46875 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 14:10:36 -0800 Subject: [PATCH 08/44] exclude the negative pair from dataloader and forward pass --- viscy/data/triplet.py | 27 ++++++++++------- viscy/representation/engine.py | 54 ++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index b816b28c..b348a41e 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -62,6 +62,7 @@ def __init__( include_fov_names: list[str] | None = None, include_track_ids: list[int] | None = None, time_interval: Literal["any"] | int = "any", + return_negative: bool = True, ) -> None: """Dataset for triplet sampling of cells based on tracking. @@ -118,6 +119,7 @@ def __init__( self._specific_cells(self.tracks) if self.predict_cells else self.tracks ) self.valid_anchors = self._filter_anchors(self.tracks) + self.return_negative = return_negative def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame: """Exclude tracks that are too close to the border or do not have the next time point. @@ -242,15 +244,16 @@ def __getitem__(self, index: int) -> TripletSample: patch=positive_patch, norm_meta=positive_norm, ) - negative_row = self._sample_negative(anchor_row) - negative_patch, negative_norm = self._slice_patch(negative_row) - if self.negative_transform: - negative_patch = _transform_channel_wise( - transform=self.negative_transform, - channel_names=self.channel_names, - patch=negative_patch, - norm_meta=negative_norm, - ) + if self.return_negative: + negative_row = self._sample_negative(anchor_row) + negative_patch, negative_norm = self._slice_patch(negative_row) + if self.negative_transform: + negative_patch = _transform_channel_wise( + transform=self.negative_transform, + channel_names=self.channel_names, + patch=negative_patch, + norm_meta=negative_norm, + ) if self.anchor_transform: anchor_patch = _transform_channel_wise( transform=self.anchor_transform, @@ -259,8 +262,12 @@ def __getitem__(self, index: int) -> TripletSample: norm_meta=anchor_norm, ) sample = {"anchor": anchor_patch} + if self.fit: - sample.update({"positive": positive_patch, "negative": negative_patch}) + if self.return_negative: + sample.update({"positive": positive_patch, "negative": negative_patch}) + else: + sample.update({"positive": positive_patch}) else: sample.update({"index": anchor_row[INDEX_COLUMNS].to_dict()}) return sample diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 2aec9853..b8bcf142 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -57,7 +57,7 @@ def forward(self, zis, zjs): # Find the valid pairs of positive samples positive_samples = torch.cat( [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 - ) + ).to(similarity_matrix.device) # Mask out unwanted pairs similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) @@ -120,7 +120,7 @@ def print_embedding_norms(self, anchor, positive, negative, phase): _logger.debug(f"{phase}/negative_norm: {negative_norm}") def _log_metrics( - self, loss, anchor, positive, negative, stage: Literal["train", "val"] + self, loss, anchor, positive, stage: Literal["train", "val"], negative=None ): self.log( f"loss/{stage}", @@ -131,17 +131,26 @@ def _log_metrics( logger=True, sync_dist=True, ) + cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean() - cosine_sim_neg = F.cosine_similarity(anchor, negative, dim=1).mean() euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean() - euclidean_dist_neg = F.pairwise_distance(anchor, negative).mean() + log_metric_dict = { + f"metrics/cosine_similarity_positive/{stage}": cosine_sim_pos, + f"metrics/euclidean_distance_positive/{stage}": euclidean_dist_pos, + } + + if negative is not None: + euclidean_dist_neg = F.pairwise_distance(anchor, negative).mean() + cosine_sim_neg = F.cosine_similarity(anchor, negative, dim=1).mean() + log_metric_dict[f"metrics/cosine_similarity_negative/{stage}"] = ( + cosine_sim_neg + ) + log_metric_dict[f"metrics/euclidean_distance_negative/{stage}"] = ( + euclidean_dist_neg + ) + self.log_dict( - { - f"metrics/cosine_similarity_positive/{stage}": cosine_sim_pos, - f"metrics/cosine_similarity_negative/{stage}": cosine_sim_neg, - f"metrics/euclidean_distance_positive/{stage}": euclidean_dist_pos, - f"metrics/euclidean_distance_negative/{stage}": euclidean_dist_neg, - }, + log_metric_dict, on_step=False, on_epoch=True, logger=True, @@ -157,24 +166,31 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_img = batch["anchor"] pos_img = batch["positive"] - neg_img = batch["negative"] anchor_projection = self(anchor_img) - negative_projection = self(neg_img) positive_projection = self(pos_img) if isinstance(self.loss_function, NTXentLoss): # Note: we assume the two augmented views are the anchor and positive samples loss = self.loss_function(anchor_projection, positive_projection) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=None, + stage="train", + ) else: + neg_img = batch["negative"] + negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss, - anchor_projection, - positive_projection, - negative_projection, - stage="train", - ) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="train", + ) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( detach_sample( From 771d33db20aae118981a6b914f60b7bbb3c952b4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 19:09:43 -0800 Subject: [PATCH 09/44] adding option using pytorch-metric-learning implementation and modifying previous to match same input args --- pyproject.toml | 1 + viscy/representation/engine.py | 114 ++++++++++++++++++++++----------- 2 files changed, 77 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 55be6805..32f196c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "matplotlib>=3.9.0", "numpy", "xarray", + "pytorch-metric-learning>2.0.0" ] dynamic = ["version"] diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index b8bcf142..a67744f4 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -10,23 +10,27 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images +from pytorch_metric_learning.losses import SelfSupervisedLoss +from pytorch_metric_learning.losses import NTXentLoss as NTXentLoss_pml _logger = logging.getLogger("lightning.pytorch") -class NTXentLoss(torch.nn.Module): +class NTXentLoss_viscy(torch.nn.Module): """ Normalized Temperature-scaled Cross Entropy Loss From Chen et.al, https://arxiv.org/abs/2002.05709 """ - def __init__(self, batch_size, temperature=0.5): - super(NTXentLoss, self).__init__() - self.batch_size = batch_size + def __init__( + self, + temperature=0.5, + criterion=torch.nn.CrossEntropyLoss(reduction="sum"), + ): + super(NTXentLoss_viscy, self).__init__() self.temperature = temperature - self.mask = self._get_correlated_mask(batch_size) - self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + self.criterion = criterion def _get_correlated_mask(self, batch_size): mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) @@ -34,37 +38,36 @@ def _get_correlated_mask(self, batch_size): for i in range(batch_size): mask[i, batch_size + i] = 0 mask[batch_size + i, i] = 0 + _logger.info(f"mask: {mask}") return mask @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, zis, zjs): + def forward(self, embeddings, labels): """ - zis and zjs are the output projections from the two augmented views - + embeddings = [zis, zjs] + zis and zjs are the output projections from the two augmented views. Here, we assume the two augmented views are the anchor and positive samples """ - # Concatenate representations along the batch dimension - representations = torch.cat([zis, zjs], dim=0) + # Get the batch size from tensor + batch_size = embeddings.shape[0] // 2 + + zis, zjs = torch.split(embeddings, batch_size, dim=0) # Cosine similarity similarity_matrix = F.cosine_similarity( - representations.unsqueeze(1), representations.unsqueeze(0), dim=2 + embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2 ) - # Temperature scaling similarity_matrix = similarity_matrix / self.temperature - # Find the valid pairs of positive samples - positive_samples = torch.cat( - [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 - ).to(similarity_matrix.device) + mask = self._get_correlated_mask(batch_size).to(similarity_matrix.device) # Mask out unwanted pairs - similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) + similarity_matrix = similarity_matrix[mask].view(2 * batch_size, -1) # Calculate NT-Xent Loss as cross-entropy - loss = self.criterion(similarity_matrix, positive_samples) - loss /= 2 * self.batch_size + loss = self.criterion(similarity_matrix, labels) + loss /= 2 * batch_size return loss @@ -82,7 +85,11 @@ def __init__( self, encoder: nn.Module | ContrastiveEncoder, loss_function: ( - nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + nn.Module + | nn.CosineEmbeddingLoss + | nn.TripletMarginLoss + | NTXentLoss_pml + | NTXentLoss_viscy ) = nn.TripletMarginLoss(margin=0.5), lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", @@ -168,9 +175,14 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor_img) positive_projection = self(pos_img) - if isinstance(self.loss_function, NTXentLoss): + if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + indices = torch.arange( + 0, anchor_projection.size(0), device=anchor_projection.device + ) + labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - loss = self.loss_function(anchor_projection, positive_projection) + embeddings = torch.cat((anchor_projection, positive_projection)) + loss = self.loss_function(embeddings, labels) self._log_metrics( loss=loss, anchor=anchor_projection, @@ -178,6 +190,10 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: negative=None, stage="train", ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample((anchor_img, pos_img), self.log_samples_per_batch) + ) else: neg_img = batch["negative"] negative_projection = self(neg_img) @@ -191,12 +207,12 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: negative=negative_projection, stage="train", ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (anchor_img, pos_img, neg_img), self.log_samples_per_batch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample( + (anchor_img, pos_img, neg_img), self.log_samples_per_batch + ) ) - ) return loss def on_train_epoch_end(self) -> None: @@ -208,24 +224,46 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: """Validation step of the model.""" anchor = batch["anchor"] pos_img = batch["positive"] - neg_img = batch["negative"] anchor_projection = self(anchor) - negative_projection = self(neg_img) positive_projection = self(pos_img) - if isinstance(self.loss_function, NTXentLoss): + if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + indices = torch.arange( + 0, anchor_projection.size(0), device=anchor_projection.device + ) + labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - loss = self.loss_function(anchor_projection, positive_projection) + embeddings = torch.cat((anchor_projection, positive_projection)) + loss = self.loss_function(embeddings, labels) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=None, + stage="val", + ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + detach_sample((anchor, pos_img), self.log_samples_per_batch) + ) else: + neg_img = batch["negative"] + negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss, anchor_projection, positive_projection, negative_projection, "val" - ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample((anchor, pos_img, neg_img), self.log_samples_per_batch) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="val", ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + detach_sample( + (anchor, pos_img, neg_img), self.log_samples_per_batch + ) + ) return loss def on_validation_epoch_end(self) -> None: From 46974bcb2c5d7b7cd3c2894fd686744a13342c26 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 14 Nov 2024 14:34:05 -0800 Subject: [PATCH 10/44] removing our implementation of NTXentLoss and using pytorch metric --- viscy/representation/engine.py | 71 +++------------------------------- 1 file changed, 5 insertions(+), 66 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index a67744f4..6ed16657 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,77 +1,20 @@ import logging -from typing import Literal, Sequence, TypedDict, Tuple +from typing import Literal, Sequence, Tuple, TypedDict import numpy as np import torch import torch.nn.functional as F from lightning.pytorch import LightningModule +from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images -from pytorch_metric_learning.losses import SelfSupervisedLoss -from pytorch_metric_learning.losses import NTXentLoss as NTXentLoss_pml _logger = logging.getLogger("lightning.pytorch") -class NTXentLoss_viscy(torch.nn.Module): - """ - Normalized Temperature-scaled Cross Entropy Loss - - From Chen et.al, https://arxiv.org/abs/2002.05709 - """ - - def __init__( - self, - temperature=0.5, - criterion=torch.nn.CrossEntropyLoss(reduction="sum"), - ): - super(NTXentLoss_viscy, self).__init__() - self.temperature = temperature - self.criterion = criterion - - def _get_correlated_mask(self, batch_size): - mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) - mask = mask.fill_diagonal_(0) - for i in range(batch_size): - mask[i, batch_size + i] = 0 - mask[batch_size + i, i] = 0 - _logger.info(f"mask: {mask}") - return mask - - @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, embeddings, labels): - """ - embeddings = [zis, zjs] - zis and zjs are the output projections from the two augmented views. - Here, we assume the two augmented views are the anchor and positive samples - """ - # Get the batch size from tensor - batch_size = embeddings.shape[0] // 2 - - zis, zjs = torch.split(embeddings, batch_size, dim=0) - - # Cosine similarity - similarity_matrix = F.cosine_similarity( - embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2 - ) - # Temperature scaling - similarity_matrix = similarity_matrix / self.temperature - - mask = self._get_correlated_mask(batch_size).to(similarity_matrix.device) - - # Mask out unwanted pairs - similarity_matrix = similarity_matrix[mask].view(2 * batch_size, -1) - - # Calculate NT-Xent Loss as cross-entropy - loss = self.criterion(similarity_matrix, labels) - loss /= 2 * batch_size - - return loss - - class ContrastivePrediction(TypedDict): features: Tensor projections: Tensor @@ -85,11 +28,7 @@ def __init__( self, encoder: nn.Module | ContrastiveEncoder, loss_function: ( - nn.Module - | nn.CosineEmbeddingLoss - | nn.TripletMarginLoss - | NTXentLoss_pml - | NTXentLoss_viscy + nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss ) = nn.TripletMarginLoss(margin=0.5), lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", @@ -175,7 +114,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor_img) positive_projection = self(pos_img) - if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device ) @@ -226,7 +165,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor) positive_projection = self(pos_img) - if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device ) From 1fde4ab9f623aa5e1a525c78a1b9ada448fdddb4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 14 Nov 2024 14:46:05 -0800 Subject: [PATCH 11/44] ruff --- viscy/representation/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 6ed16657..832fd62d 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, Tuple, TypedDict +from typing import Literal, Sequence, TypedDict import numpy as np import torch From 3f8363d1819d3cca1609b39c1dfe610c2efc7774 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 19 Nov 2024 16:25:38 -0800 Subject: [PATCH 12/44] prototype for phate and umap plot --- pyproject.toml | 1 + viscy/representation/engine.py | 29 +++++++++++++++++++ .../evaluation/dimensionality_reduction.py | 26 +++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 32f196c8..6d97e8ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ visual = [ "plotly", "nbformat", "cmap", + "phate" ] dev = [ "viscy[metrics,examples,visual]", diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 832fd62d..927b6169 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -7,6 +7,7 @@ from lightning.pytorch import LightningModule from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn +from umap import UMAP from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder @@ -34,6 +35,7 @@ def __init__( schedule: Literal["WarmupCosine", "Constant"] = "Constant", log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, + log_embeddings: bool = False, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), ) -> None: super().__init__() @@ -46,6 +48,7 @@ def __init__( self.example_input_array = torch.rand(*example_input_array_shape) self.training_step_outputs = [] self.validation_step_outputs = [] + self.log_embeddings = log_embeddings def forward(self, x: Tensor) -> Tensor: "Only return projected embeddings for training and validation." @@ -109,6 +112,19 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): key, grid, self.current_epoch, dataformats="HWC" ) + def log_embedding_umap(self, embeddings: Tensor, tag: str): + _logger.debug(f"Computing UMAP for {tag} embeddings.") + umap = UMAP(n_components=2) + embeddings_np = embeddings.detach().cpu().numpy() + umap_embeddings = umap.fit_transform(embeddings_np) + + # Log UMAP embeddings to TensorBoard + self.logger.experiment.add_embedding( + umap_embeddings, + global_step=self.current_epoch, + tag=f"{tag}_umap", + ) + def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_img = batch["anchor"] pos_img = batch["positive"] @@ -157,6 +173,12 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_train_epoch_end(self) -> None: super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) + # Log UMAP embeddings for validation + if self.log_embeddings: + embeddings = torch.cat( + [output["embeddings"] for output in self.validation_step_outputs] + ) + self.log_embedding_umap(embeddings, tag="train") self.training_step_outputs = [] def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: @@ -208,6 +230,13 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) + # Log UMAP embeddings for training + if self.log_embeddings: + embeddings = torch.cat( + [output["embeddings"] for output in self.training_step_outputs] + ) + self.log_embedding_umap(embeddings, tag="val") + self.validation_step_outputs = [] def configure_optimizers(self): diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 0a906bf4..55fff6f0 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -8,6 +8,32 @@ from xarray import Dataset +def compute_phate( + embedding_dataset, + n_components: int = None, + knn: int = 5, + decay: int = 40, + **phate_kwargs, +): + import phate + + features = embedding_dataset["features"] + projections = embedding_dataset["projections"] + + phate_operator = phate.PHATE( + n_components=n_components, knn=knn, decay=decay, **phate_kwargs + ) + phate_embedding = phate_operator.fit_transform(embedding_dataset["features"].values) + phate_projections = phate_operator.transform( + embedding_dataset["projections"].values + ) + phate_df = pd.DataFrame( + + ) + + return (phate_embedding, phate_projections, phate_df) + + def compute_pca(embedding_dataset, n_components=None, normalize_features=True): features = embedding_dataset["features"] projections = embedding_dataset["projections"] From 26ebe74bf5cd1be9a069f5fb3f5312b04415fb90 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 15:24:27 -0800 Subject: [PATCH 13/44] - proofreading the calculations - removing unecessary calls to ALFI script - simplifying code to re-use functions --- .../evaluation/ALFI_displacement.py | 233 +++--------- viscy/representation/evaluation/distance.py | 335 ++++-------------- 2 files changed, 112 insertions(+), 456 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py index ba36e84d..595f283f 100644 --- a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py +++ b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py @@ -5,6 +5,9 @@ from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.distance import ( calculate_normalized_euclidean_distance_cell, + compute_displacement, + compute_dynamic_range, + compute_rms_per_track, ) from collections import defaultdict from tabulate import tabulate @@ -13,156 +16,10 @@ from sklearn.metrics.pairwise import cosine_similarity from collections import OrderedDict -# %% function +# %% function -def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): - """ - Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - max_tau : int, optional - The maximum tau value to compute displacements for. Default is 10. - - Returns: - tuple - - mean_displacement_per_tau (dict): Mean displacement for each tau. - - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. - """ - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - cell_identifiers = np.array( - list(zip(fov_names, track_ids)), - dtype=[("fov_name", "O"), ("track_id", "int64")], - ) - - unique_cells = np.unique(cell_identifiers) - - displacement_per_tau = defaultdict(list) - - for cell in unique_cells: - fov_name = cell["fov_name"] - track_id = cell["track_id"] - - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - - cell_timepoints = timepoints[indices] - cell_embeddings = embeddings[indices] - - sorted_indices = np.argsort(cell_timepoints) - cell_timepoints = cell_timepoints[sorted_indices] - - cell_embeddings = cell_embeddings[sorted_indices] - - for i in range(len(cell_timepoints)): - current_time = cell_timepoints[i] - - current_embedding = cell_embeddings[i] - - current_embedding = current_embedding / np.linalg.norm(current_embedding) - - for tau in range(0, max_tau + 1): - future_time = current_time + tau - - - future_index = np.where(cell_timepoints == future_time)[0] - - if len(future_index) >= 1: - - future_embedding = cell_embeddings[future_index[0]] - future_embedding = future_embedding / np.linalg.norm( - future_embedding - ) - - distance = np.linalg.norm(current_embedding - future_embedding) - - displacement_per_tau[tau].append(distance) - - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - - return mean_displacement_per_tau, std_displacement_per_tau - - -def compute_dynamic_range(mean_displacement_per_tau): - """ - Compute the dynamic range as the difference between the maximum - and minimum mean displacement per τ. - - Parameters: - mean_displacement_per_tau: dict with τ as key and mean displacement as value - - Returns: - float: dynamic range (max displacement - min displacement) - """ - displacements = mean_displacement_per_tau.values() - return max(displacements) - min(displacements) - - -def compute_rms_per_track(embedding_dataset): - """ - Compute RMS of the time derivative of embeddings per track. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - - Returns: - list: A list of RMS values, one for each track. - """ - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - cell_identifiers = np.array( - list(zip(fov_names, track_ids)), - dtype=[("fov_name", "O"), ("track_id", "int64")], - ) - - unique_cells = np.unique(cell_identifiers) - - rms_values = [] - - for cell in unique_cells: - fov_name = cell["fov_name"] - track_id = cell["track_id"] - - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - - cell_timepoints = timepoints[indices] - cell_embeddings = embeddings[indices] - #print(cell_embeddings.shape) - - if len(cell_embeddings) < 2: - continue - - sorted_indices = np.argsort(cell_timepoints) - cell_embeddings = cell_embeddings[sorted_indices] - - # Compute differences between consecutive embeddings - differences = np.diff(cell_embeddings, axis=0) # Shape: (T-1, 768) - - if differences.shape[0] == 0: - continue - - # Compute RMS for this track - norms = np.linalg.norm(differences, axis=1) - if len(norms) > 0: - rms = np.sqrt(np.mean(norms**2)) - rms_values.append(rms) - - return rms_values +# Removed redundant compute_displacement_mean_std_full function +# Removed redundant compute_dynamic_range and compute_rms_per_track functions def plot_rms_histogram(rms_values, label, bins=30): @@ -188,7 +45,10 @@ def plot_rms_histogram(rms_values, label, bins=30): plt.grid(True) plt.show() -def plot_displacement(mean_displacement, std_displacement, label, metrics_no_track=None): + +def plot_displacement( + mean_displacement, std_displacement, label, metrics_no_track=None +): """ Plot embedding displacement over time with mean and standard deviation. @@ -247,6 +107,7 @@ def plot_displacement(mean_displacement, std_displacement, label, metrics_no_tra plt.legend(fontsize=12) plt.show() + def plot_overlay_displacement(overlay_displacement_data): """ Plot embedding displacement over time for all datasets in one plot. @@ -263,7 +124,7 @@ def plot_overlay_displacement(overlay_displacement_data): taus = list(mean_displacement.keys()) mean_values = list(mean_displacement.values()) plt.plot(taus, mean_values, marker="o", label=label) - + plt.xlabel("Time Shift (τ)", fontsize=14) plt.ylabel("Euclidean Distance", fontsize=14) plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) @@ -271,7 +132,8 @@ def plot_overlay_displacement(overlay_displacement_data): plt.legend(fontsize=12) plt.show() -# %% hist stats + +# %% hist stats def plot_boxplot_rms_across_models(datasets_rms): """ Plot a boxplot for the distribution of RMS values across models. @@ -290,11 +152,13 @@ def plot_boxplot_rms_across_models(datasets_rms): print(data) # Plot the boxplot plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) - - plt.title("Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16) + + plt.title( + "Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16 + ) plt.ylabel("RMS of Time Derivative", fontsize=14) plt.xticks(rotation=45, fontsize=12) - plt.grid(axis='y', linestyle='--', alpha=0.7) + plt.grid(axis="y", linestyle="--", alpha=0.7) plt.tight_layout() plt.show() @@ -313,7 +177,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff): plt.figure(figsize=(12, 6)) for label, abs_diff in datasets_abs_diff.items(): plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) - + plt.title("Histograms of Absolute Differences Across Models", fontsize=16) plt.xlabel("Absolute Difference", fontsize=14) plt.ylabel("Density", fontsize=14) @@ -322,6 +186,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff): plt.tight_layout() plt.show() + # %% Paths to datasets feature_paths = { "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", @@ -345,8 +210,8 @@ def plot_histogram_absolute_differences(datasets_abs_diff): features_path_no_track = Path(no_track_path) embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) -mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std_full( - embedding_dataset_no_track, max_tau=max_tau +mean_displacement_no_track, std_displacement_no_track = compute_displacement( + embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True ) dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) metrics["No Tracking"] = { @@ -360,11 +225,7 @@ def plot_histogram_absolute_differences(datasets_abs_diff): print("\nProcessing No Tracking dataset...") print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") -plot_displacement( - mean_displacement_no_track, - std_displacement_no_track, - "No Tracking" -) +plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking") rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) datasets_rms["No Tracking"] = rms_values_no_track @@ -373,14 +234,18 @@ def plot_histogram_absolute_differences(datasets_abs_diff): plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) # Compute absolute differences for "No Tracking" -abs_diff_no_track = np.concatenate([ - np.linalg.norm( - np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), - axis=-1 - ) - for indices in np.split(np.arange(len(embedding_dataset_no_track["track_id"])), - np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1) -]) +abs_diff_no_track = np.concatenate( + [ + np.linalg.norm( + np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), + axis=-1, + ) + for indices in np.split( + np.arange(len(embedding_dataset_no_track["track_id"])), + np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1, + ) + ] +) datasets_abs_diff["No Tracking"] = abs_diff_no_track # Process other datasets @@ -390,8 +255,8 @@ def plot_histogram_absolute_differences(datasets_abs_diff): features_path = Path(path) embedding_dataset = read_embedding_dataset(features_path) - mean_displacement, std_displacement = compute_displacement_mean_std_full( - embedding_dataset, max_tau=max_tau + mean_displacement, std_displacement = compute_displacement( + embedding_dataset, max_tau=max_tau, return_mean_std=True ) dynamic_range = compute_dynamic_range(mean_displacement) metrics[label] = { @@ -417,16 +282,17 @@ def plot_histogram_absolute_differences(datasets_abs_diff): print(f"Plotting histogram of RMS values for {label}...") plot_rms_histogram(rms_values, label, bins=30) - abs_diff = np.concatenate([ - np.linalg.norm( - np.diff(embedding_dataset["features"].values[indices], axis=0), - axis=-1 - ) - for indices in np.split( - np.arange(len(embedding_dataset["track_id"])), - np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1 - ) - ]) + abs_diff = np.concatenate( + [ + np.linalg.norm( + np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1 + ) + for indices in np.split( + np.arange(len(embedding_dataset["track_id"])), + np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1, + ) + ] + ) datasets_abs_diff[label] = abs_diff print("\nPlotting overlayed displacement for all datasets...") @@ -443,3 +309,4 @@ def plot_histogram_absolute_differences(datasets_abs_diff): plot_histogram_absolute_differences(datasets_abs_diff) +# %% diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 5968fb03..68a794b1 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,134 +1,22 @@ from collections import defaultdict - import numpy as np from sklearn.metrics.pairwise import cosine_similarity def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): - """ - Calculate cosine similarities for a specific cell over time. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - fov_name : str - Field of view name to identify the specific cell. - track_id : int - Track ID to identify the specific cell. - - Returns: - tuple - - time_points (array): Array of time points for the cell. - - cosine_similarities (list): Cosine similarities between the embedding - at the first time point and each subsequent time point. - """ - # Filter the dataset for the specific infected cell + """Extract embeddings and calculate cosine similarities for a specific cell""" filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), drop=True, ) - - # Extract the feature embeddings and time points features = filtered_data["features"].values # (sample, features) time_points = filtered_data["t"].values # (sample,) - - # Get the first time point's embedding first_time_point_embedding = features[0].reshape(1, -1) - - # Calculate cosine similarity between each time point and the first time point - cosine_similarities = [] - for i in range(len(time_points)): - similarity = cosine_similarity( - first_time_point_embedding, features[i].reshape(1, -1) - ) - cosine_similarities.append(similarity[0][0]) - - return time_points, cosine_similarities - - -def compute_displacement_mean_std( - embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False -): - """ - Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for each tau. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - max_tau : int, optional - The maximum tau value to compute displacements for. Default is 10. - use_cosine : bool, optional - If True, compute cosine similarity instead of Euclidean distance. Default is False. - use_dissimilarity : bool, optional - If True and use_cosine is True, compute cosine dissimilarity (1 - similarity). - Default is False. - - Returns: - tuple - - mean_displacement_per_tau (dict): Mean displacement for each tau. - - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. - """ - # Get the arrays of (fov_name, track_id, t, and embeddings) - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - # Dictionary to store displacements for each tau - displacement_per_tau = defaultdict(list) - - # Iterate over all entries in the dataset - for i in range(len(fov_names)): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i] - - # For each time point t, compute displacements for t + tau - for tau in range(1, max_tau + 1): - future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) - matching_indices = np.where( - (fov_names == fov_name) - & (track_ids == track_id) - & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] - - if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity( - current_embedding.reshape(1, -1), - future_embedding.reshape(1, -1), - )[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity - else: - # Compute the Euclidean distance, elementwise square on difference - displacement = np.sum((current_embedding - future_embedding) ** 2) - - # Store the displacement for the given tau - displacement_per_tau[tau].append(displacement) - - # Compute mean and std displacement for each tau by averaging the displacements - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - - return mean_displacement_per_tau, std_displacement_per_tau + cosine_similarities = cosine_similarity( + first_time_point_embedding, features + ).flatten() + return time_points, cosine_similarities.tolist() def compute_displacement( @@ -137,55 +25,30 @@ def compute_displacement( use_cosine=False, use_dissimilarity=False, use_umap=False, + return_mean_std=False, ): - """ - Compute the displacements between embeddings at time t and t + tau for each tau. - - Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - max_tau : int, optional - The maximum tau value to compute displacements for. Default is 10. - use_cosine : bool, optional - If True, compute cosine similarity instead of Euclidean distance. Default is False. - use_dissimilarity : bool, optional - If True and use_cosine is True, compute cosine dissimilarity (1 - similarity). - Default is False. - use_umap : bool, optional - If True, use UMAP embeddings instead of feature embeddings. Default is False. - - Returns: - dict - A dictionary where the key is tau and the value is a list of displacements - for all cells at that tau. - """ - # Get the arrays of (fov_name, track_id, t, and embeddings) + """Compute the norm of differences between embeddings at t and t + tau""" fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values if use_umap: - umap1 = embedding_dataset["UMAP1"].values - umap2 = embedding_dataset["UMAP2"].values - embeddings = np.vstack((umap1, umap2)).T + embeddings = np.vstack( + (embedding_dataset["UMAP1"].values, embedding_dataset["UMAP2"].values) + ).T else: embeddings = embedding_dataset["features"].values - # Dictionary to store displacements for each tau displacement_per_tau = defaultdict(list) - # Iterate over all entries in the dataset for i in range(len(fov_names)): fov_name = fov_names[i] track_id = track_ids[i] current_time = timepoints[i] - current_embedding = embeddings[i] + current_embedding = embeddings[i].reshape(1, -1) - # For each time point t, compute displacements for t + tau for tau in range(1, max_tau + 1): future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) matching_indices = np.where( (fov_names == fov_name) & (track_ids == track_id) @@ -193,86 +56,57 @@ def compute_displacement( )[0] if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity( - current_embedding.reshape(1, -1), - future_embedding.reshape(1, -1), - )[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity + similarity = cosine_similarity(current_embedding, future_embedding)[ + 0 + ][0] + displacement = 1 - similarity if use_dissimilarity else similarity else: - # Compute the Euclidean distance, elementwise square on difference displacement = np.sum((current_embedding - future_embedding) ** 2) - # Store the displacement for the given tau displacement_per_tau[tau].append(displacement) + if return_mean_std: + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + return mean_displacement_per_tau, std_displacement_per_tau + return displacement_per_tau -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): +def compute_dynamic_range(mean_displacement_per_tau): """ - Calculate the normalized Euclidean distance between the embedding at the first time point and each subsequent time point for a specific cell. + Compute the dynamic range as the difference between the maximum + and minimum mean displacement per τ. Parameters: - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. - fov_name : str - Field of view name to identify the specific cell. - track_id : int - Track ID to identify the specific cell. + mean_displacement_per_tau: dict with τ as key and mean displacement as value Returns: - tuple - - time_points (array): Array of time points for the cell. - - euclidean_distances (list): Normalized Euclidean distances between the embedding - at the first time point and each subsequent time point. + float: dynamic range (max displacement - min displacement) """ - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) + displacements = list(mean_displacement_per_tau.values()) + return max(displacements) - min(displacements) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) - - # Get the first time point's normalized embedding - first_time_point_embedding = normalized_features[0].reshape(1, -1) - - euclidean_distances = [] - for i in range(len(time_points)): - distance = np.linalg.norm( - first_time_point_embedding - normalized_features[i].reshape(1, -1) - ) - euclidean_distances.append(distance) - - return time_points, euclidean_distances - -def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): +def compute_rms_per_track(embedding_dataset): """ - Compute the mean and standard deviation of displacements between embeddings at time t and t + tau for all cells. + Compute RMS of the time derivative of embeddings per track. Parameters: embedding_dataset : xarray.Dataset The dataset containing embeddings, timepoints, fov_name, and track_id. - max_tau : int, optional - The maximum tau value to compute displacements for. Default is 10. Returns: - tuple - - mean_displacement_per_tau (dict): Mean displacement for each tau. - - std_displacement_per_tau (dict): Standard deviation of displacements for each tau. + list: A list of RMS values, one for each track. """ fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values @@ -283,90 +117,45 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): list(zip(fov_names, track_ids)), dtype=[("fov_name", "O"), ("track_id", "int64")], ) - unique_cells = np.unique(cell_identifiers) - displacement_per_tau = defaultdict(list) + rms_values = [] for cell in unique_cells: fov_name = cell["fov_name"] track_id = cell["track_id"] - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - cell_timepoints = timepoints[indices] cell_embeddings = embeddings[indices] + if len(cell_embeddings) < 2: + continue + sorted_indices = np.argsort(cell_timepoints) - cell_timepoints = cell_timepoints[sorted_indices] cell_embeddings = cell_embeddings[sorted_indices] + differences = np.diff(cell_embeddings, axis=0) - for i in range(len(cell_timepoints)): - current_time = cell_timepoints[i] - current_embedding = cell_embeddings[i] - - current_embedding = current_embedding / np.linalg.norm(current_embedding) - for tau in range(0, max_tau + 1): - future_time = current_time + tau + if differences.shape[0] == 0: + continue - future_index = np.where(cell_timepoints == future_time)[0] + norms = np.linalg.norm(differences, axis=1) + rms = np.sqrt(np.mean(norms**2)) + rms_values.append(rms) - if len(future_index) >= 1: + return rms_values - future_embedding = cell_embeddings[future_index[0]] - future_embedding = future_embedding / np.linalg.norm( - future_embedding - ) - distance = np.linalg.norm(current_embedding - future_embedding) - - displacement_per_tau[tau].append(distance) - - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - - return mean_displacement_per_tau, std_displacement_per_tau - - -# Function to compute metrics for dynamic range and smoothness -def compute_dynamic_smoothness_metrics(mean_displacement_per_tau): - """ - Compute dynamic range and smoothness metrics for displacement curves. - - Parameters: - mean_displacement_per_tau: dict with tau as key and mean displacement as value - - Returns: - tuple: (dynamic_range, smoothness) - - dynamic_range: max displacement - min displacement - - smoothness: RMS of second differences of normalized curve - """ - taus = np.array(sorted(mean_displacement_per_tau.keys())) - displacements = np.array([mean_displacement_per_tau[tau] for tau in taus]) - - dynamic_range = np.max(displacements) - np.min(displacements) - - if np.max(displacements) != np.min(displacements): - displacements_normalized = (displacements - np.min(displacements)) / ( - np.max(displacements) - np.min(displacements) - ) - else: - displacements_normalized = displacements - np.min( - displacements - ) # Handle constant case - - first_diff = np.diff(displacements_normalized) - - second_diff = np.diff(first_diff) - - # Compute RMS of second differences as smoothness metric - # Lower values indicate smoother curves - smoothness = np.sqrt(np.mean(second_diff**2)) - - return dynamic_range, smoothness +def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) + first_time_point_embedding = normalized_features[0].reshape(1, -1) + euclidean_distances = np.linalg.norm( + first_time_point_embedding - normalized_features, axis=1 + ) + return time_points, euclidean_distances.tolist() From 7a3b03666ee49c84b67be0ad33ed5d48a1bd00b3 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 20 Nov 2024 16:50:43 -0800 Subject: [PATCH 14/44] methods to rank nearest neighbors in embeddings --- viscy/representation/evaluation/clustering.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index d87d3968..43480c70 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -1,5 +1,8 @@ """Methods for evaluating clustering performance.""" +import numpy as np +from numpy.typing import ArrayLike, NDArray +from scipy.spatial.distance import cdist from sklearn.cluster import DBSCAN from sklearn.metrics import ( accuracy_score, @@ -30,6 +33,78 @@ def knn_accuracy(embeddings, annotations, k=5): return accuracy +def cross_dissimilarity(features: ArrayLike, metric: str = "cosine") -> NDArray: + """Dissimilarity/distance between each pair of samples in the features. + + Parameters + ---------- + features : ArrayLike + Feature matrix (n_samples, n_features) + metric : str, optional + Metric type, by default "cosine" (cosine dissimilarity) + + Returns + ------- + NDArray + Dissimilarity square matrix (n_samples, n_samples) + """ + return cdist(features, features, metric=metric) + + +def rank_nearest_neighbors( + cross_dissimilarity: NDArray, normalize: bool = True +) -> NDArray: + """Rank each sample by (dis)similarity to all other samples. + + Parameters + ---------- + cross_dissimilarity : NDArray + Dissimilarity square matrix (n_samples, n_samples) + normalize : bool, optional + Normalize the rank matrix by sample size, by default True + If normalized, self (diagonal) will be at fraction 0, + and the farthest sample will be at fraction 1. + + Returns + ------- + NDArray + Rank matrix (n_samples, n_samples) + Ranking is done on axis=1 + """ + rankings = np.argsort(np.argsort(cross_dissimilarity, axis=1), axis=1) + if normalize: + rankings = rankings.astype(np.float64) / (rankings.shape[1] - 1) + return rankings + + +def select_block(distances: NDArray, index: NDArray) -> NDArray: + """Select with the same indexes along both dimensions for a square matrix.""" + return distances[index][:, index] + + +def compare_time_offset( + single_track_distances: NDArray, time_offset: int = 1 +) -> NDArray: + """Extract the nearest neighbor distances/rankings + of the next sample compared to each sample. + + Parameters + ---------- + single_track_distances : NDArray + Distances or rankings of a single track (n_samples, n_samples) + If the matrix is not symmetric (e.g. is rankings), + it should measured along dimension 1 + sample_offset : int, optional + Offset from the diagonal, by default 1 (the next sample in time) + + Returns + ------- + NDArray + Distances/rankings vector (n_samples - time_offset,) + """ + return single_track_distances.diagonal(offset=-time_offset) + + def dbscan_clustering(embeddings, eps=0.5, min_samples=5): """ Apply DBSCAN clustering to the embeddings. From 3992d9dbefbd4f8f6b0545aaeb4d9af40cfcf413 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 20 Nov 2024 16:51:06 -0800 Subject: [PATCH 15/44] example script to plot state change of a single track --- .../evaluation/time_decay_knn.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 applications/contrastive_phenotyping/evaluation/time_decay_knn.py diff --git a/applications/contrastive_phenotyping/evaluation/time_decay_knn.py b/applications/contrastive_phenotyping/evaluation/time_decay_knn.py new file mode 100644 index 00000000..0e914ad9 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/time_decay_knn.py @@ -0,0 +1,95 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import seaborn as sns + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + cross_dissimilarity, + rank_nearest_neighbors, + select_block, +) + +# %% +prediction_path = Path( + "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr" +) + +embeddings = read_embedding_dataset(prediction_path) +features = embeddings["features"] + +# %% +cross_dist = cross_dissimilarity(features.values, metric="cosine") +rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + +# %% +# select a single track in a single fov +fov = "/0/0/0" +fov_idx = (features["fov_name"] == fov).values + +track_id = 1 +track_idx = (features["track_id"] == track_id).values + +fov_and_track_idx = fov_idx & track_idx + +single_track_dissimilarity = select_block(cross_dist, fov_and_track_idx) +single_track_rank_fraction = select_block(rank_fractions, fov_and_track_idx) + +piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 +) +piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 +) + +# %% +f = plt.figure(figsize=(8, 12)) +f.suptitle(f"Track {track_id} in FOV {fov}") +subfigs = f.subfigures(2, 1, height_ratios=[1, 2]) + +umap = subfigs[0].subplots(1, 1) +single_cell_features = features.sel(fov_name=fov, track_id=track_id).sortby("t") +sns.lineplot( + x=single_cell_features["UMAP1"], + y=single_cell_features["UMAP2"], + ax=umap, + color="k", + alpha=0.5, +) +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], ax=umap, color="k", s=100, alpha=0.01 +) +sns.scatterplot( + x=single_cell_features["UMAP1"], + y=single_cell_features["UMAP2"], + hue=single_cell_features["t"], + ax=umap, + palette="RdYlGn", +) + +f1 = subfigs[1] +ax = f1.subplots(2, 2) + +sns.heatmap(single_track_dissimilarity, ax=ax[0, 0], square=True) +ax[0, 0].set_title("Cosine dissimilarity") +ax[0, 0].set_xlabel("Frame") +ax[0, 0].set_ylabel("Frame") + +sns.heatmap(single_track_rank_fraction, ax=ax[0, 1], square=True) +ax[0, 1].set_title("Column-wise normalized neighborhood distance") +ax[0, 1].set_xlabel("Frame") +ax[0, 1].set_ylabel("Frame") + +sns.lineplot(piece_wise_dissimilarity, ax=ax[1, 0]) +ax[1, 0].set_title("$1 - \cos{(t_i, t_{i+1})}$") +ax[1, 0].set_xlabel("Frame") +ax[1, 0].set_ylabel("Cosine dissimilarity") + +sns.lineplot(piece_wise_rank_difference, ax=ax[1, 1]) +ax[1, 1].set_title("Nearest neighbor fraction difference") +ax[1, 1].set_xlabel("Frame") +ax[1, 1].set_ylabel("Rank fraction") + +# %% From cae8290396c85b0dad322f1e85c07d5f55ccb894 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 20 Nov 2024 17:08:57 -0800 Subject: [PATCH 16/44] test using scaled features --- .../contrastive_phenotyping/evaluation/time_decay_knn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/applications/contrastive_phenotyping/evaluation/time_decay_knn.py b/applications/contrastive_phenotyping/evaluation/time_decay_knn.py index 0e914ad9..66e5f5c4 100644 --- a/applications/contrastive_phenotyping/evaluation/time_decay_knn.py +++ b/applications/contrastive_phenotyping/evaluation/time_decay_knn.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import seaborn as sns +from sklearn.preprocessing import StandardScaler from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.clustering import ( @@ -21,7 +22,10 @@ features = embeddings["features"] # %% -cross_dist = cross_dissimilarity(features.values, metric="cosine") +scaled_features = StandardScaler().fit_transform(features.values) + +# %% +cross_dist = cross_dissimilarity(scaled_features, metric="cosine") rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) # %% From 30f42469d254c4feab3394af434bfb18f2284bf6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 17:08:06 -0800 Subject: [PATCH 17/44] phate embeddings --- .../evaluation/dimensionality_reduction.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 55fff6f0..268fa81b 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -17,21 +17,27 @@ def compute_phate( ): import phate - features = embedding_dataset["features"] - projections = embedding_dataset["projections"] - phate_operator = phate.PHATE( n_components=n_components, knn=knn, decay=decay, **phate_kwargs ) phate_embedding = phate_operator.fit_transform(embedding_dataset["features"].values) - phate_projections = phate_operator.transform( - embedding_dataset["projections"].values - ) + phate_df = pd.DataFrame( - + { + "id": embedding_dataset["id"].values, + "track_id": embedding_dataset["track_id"].values, + "t": embedding_dataset["t"].values, + "fov_name": embedding_dataset["fov_name"].values, + "PHATE1": phate_embedding[:, 0], + "PHATE2": phate_embedding[:, 1], + "PHATE3": phate_embedding[:, 2], + "PHATE4": phate_embedding[:, 3], + "PHATE5": phate_embedding[:, 4], + "PHATE6": phate_embedding[:, 5], + } ) - return (phate_embedding, phate_projections, phate_df) + return (phate_embedding, phate_df) def compute_pca(embedding_dataset, n_components=None, normalize_features=True): From 356c15d2b6655334c3d470aa9908086afbea5e2a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 17:33:35 -0800 Subject: [PATCH 18/44] removing dataframe from the compute_phate adding docstring --- .../evaluation/dimensionality_reduction.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 268fa81b..34a67d20 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -15,6 +15,29 @@ def compute_phate( decay: int = 40, **phate_kwargs, ): + """ + Compute PHATE embeddings for features + + + Parameters + ---------- + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. + n_components : int, optional + Number of dimensions in the PHATE embedding, by default None + knn : int, optional + Number of nearest neighbors to use in the KNN graph, by default 5 + decay : int, optional + Decay parameter for the Markov operator, by default 40 + phate_kwargs : dict, optional + Additional keyword arguments for PHATE, by default None + + Returns + ------- + tuple[NDArray, pd.DataFrame] + PHATE embeddings and DataFrame with PHATE + + """ import phate phate_operator = phate.PHATE( @@ -22,22 +45,7 @@ def compute_phate( ) phate_embedding = phate_operator.fit_transform(embedding_dataset["features"].values) - phate_df = pd.DataFrame( - { - "id": embedding_dataset["id"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "fov_name": embedding_dataset["fov_name"].values, - "PHATE1": phate_embedding[:, 0], - "PHATE2": phate_embedding[:, 1], - "PHATE3": phate_embedding[:, 2], - "PHATE4": phate_embedding[:, 3], - "PHATE5": phate_embedding[:, 4], - "PHATE6": phate_embedding[:, 5], - } - ) - - return (phate_embedding, phate_df) + return phate_embedding def compute_pca(embedding_dataset, n_components=None, normalize_features=True): From 57abb195495d652a6abde180d6b5b5c5b4f4a8b6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 18:44:29 -0800 Subject: [PATCH 19/44] adding phate to the prediction writer and moving it as dependency. --- pyproject.toml | 2 +- viscy/representation/embedding_writer.py | 4 ++++ .../evaluation/dimensionality_reduction.py | 14 +++++++------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d97e8ba..e40db067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "matplotlib>=3.9.0", "numpy", "xarray", + "phate", "pytorch-metric-learning>2.0.0" ] dynamic = ["version"] @@ -50,7 +51,6 @@ visual = [ "plotly", "nbformat", "cmap", - "phate" ] dev = [ "viscy[metrics,examples,visual]", diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 4f324047..4d55f07f 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -12,6 +12,7 @@ from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( + compute_phate, _fit_transform_umap, ) @@ -93,8 +94,11 @@ def write_on_epoch_end( ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) _logger.info(f"Computing UMAP embeddings for {len(features)} samples.") _, umap = _fit_transform_umap(features, n_components=2, normalize=True) + _, phate = compute_phate(features) ultrack_indices["UMAP1"] = umap[:, 0] ultrack_indices["UMAP2"] = umap[:, 1] + ultrack_indices["PHATE1"] = phate[:, 0] + ultrack_indices["PHATE2"] = phate[:, 1] index = pd.MultiIndex.from_frame(ultrack_indices) dataset = Dataset( { diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 34a67d20..1c49fc21 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -6,6 +6,7 @@ from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler from xarray import Dataset +import phate def compute_phate( @@ -14,7 +15,7 @@ def compute_phate( knn: int = 5, decay: int = 40, **phate_kwargs, -): +) -> tuple[phate.PHATE, NDArray]: """ Compute PHATE embeddings for features @@ -34,18 +35,17 @@ def compute_phate( Returns ------- - tuple[NDArray, pd.DataFrame] - PHATE embeddings and DataFrame with PHATE - + phate.PHATE, NDArray + PHATE model and PHATE embeddings """ import phate - phate_operator = phate.PHATE( + phate_model = phate.PHATE( n_components=n_components, knn=knn, decay=decay, **phate_kwargs ) - phate_embedding = phate_operator.fit_transform(embedding_dataset["features"].values) + phate_embedding = phate_model.fit_transform(embedding_dataset["features"].values) - return phate_embedding + return phate_model, phate_embedding def compute_pca(embedding_dataset, n_components=None, normalize_features=True): From 19fdcf28b6b24be8eec15c8ff54adcd419171603 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 18:47:28 -0800 Subject: [PATCH 20/44] changing the phate defaults in the prediction writer. --- viscy/representation/embedding_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 4d55f07f..4bf8c4f5 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -94,7 +94,7 @@ def write_on_epoch_end( ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) _logger.info(f"Computing UMAP embeddings for {len(features)} samples.") _, umap = _fit_transform_umap(features, n_components=2, normalize=True) - _, phate = compute_phate(features) + _, phate = compute_phate(features, n_components=2, knn=5, decay=40, n_jobs=-1) ultrack_indices["UMAP1"] = umap[:, 0] ultrack_indices["UMAP2"] = umap[:, 1] ultrack_indices["PHATE1"] = phate[:, 0] From 1ee8af15d5ae38e7c2a3c15590696418371d73f1 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 18:50:14 -0800 Subject: [PATCH 21/44] ruff --- viscy/representation/embedding_writer.py | 2 +- viscy/representation/evaluation/dimensionality_reduction.py | 2 +- viscy/representation/evaluation/distance.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 4bf8c4f5..1021f2ba 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -12,8 +12,8 @@ from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( - compute_phate, _fit_transform_umap, + compute_phate, ) __all__ = ["read_embedding_dataset", "EmbeddingWriter"] diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 1c49fc21..6689d83f 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,12 +1,12 @@ """PCA and UMAP dimensionality reduction.""" import pandas as pd +import phate import umap from numpy.typing import NDArray from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler from xarray import Dataset -import phate def compute_phate( diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 68a794b1..9a1c72ef 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,4 +1,5 @@ from collections import defaultdict + import numpy as np from sklearn.metrics.pairwise import cosine_similarity From 14b73685efbb1b9d22ceb4c80ba69a0d6e2adf1d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 20 Nov 2024 22:49:01 -0800 Subject: [PATCH 22/44] fixing bug in phate in predict writer --- viscy/representation/embedding_writer.py | 6 ++++-- .../evaluation/dimensionality_reduction.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 1021f2ba..8ed6b74d 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -13,7 +13,7 @@ from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( _fit_transform_umap, - compute_phate, + _fit_transform_phate, ) __all__ = ["read_embedding_dataset", "EmbeddingWriter"] @@ -94,7 +94,9 @@ def write_on_epoch_end( ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) _logger.info(f"Computing UMAP embeddings for {len(features)} samples.") _, umap = _fit_transform_umap(features, n_components=2, normalize=True) - _, phate = compute_phate(features, n_components=2, knn=5, decay=40, n_jobs=-1) + _, phate = _fit_transform_phate( + features, n_components=2, knn=5, decay=40, n_jobs=-1 + ) ultrack_indices["UMAP1"] = umap[:, 0] ultrack_indices["UMAP2"] = umap[:, 1] ultrack_indices["PHATE1"] = phate[:, 0] diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 6689d83f..721bb8ed 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -99,6 +99,21 @@ def _fit_transform_umap( return umap_model, umap_embedding +def _fit_transform_phate( + embeddings: NDArray, + n_components: int = 2, + knn: int = 5, + decay: int = 40, + n_jobs: int = -1, +) -> tuple[phate.PHATE, NDArray]: + """Fit PHATE model and transform embeddings.""" + phate_model = phate.PHATE( + n_components=n_components, knn=knn, decay=decay, n_jobs=n_jobs + ) + phate_embedding = phate_model.fit_transform(embeddings) + return phate_model, phate_embedding + + def compute_umap( embedding_dataset: Dataset, normalize_features: bool = True ) -> tuple[umap.UMAP, umap.UMAP, pd.DataFrame]: From db653a543f871d0b4223559bc6fdb011ec829ec9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 21 Nov 2024 18:39:54 -0800 Subject: [PATCH 23/44] adding code for measuring the smoothness --- .../cosine_dissimilarity_dataset.py | 136 ++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py new file mode 100644 index 00000000..4cb0a6bc --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -0,0 +1,136 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.preprocessing import StandardScaler + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + cross_dissimilarity, + rank_nearest_neighbors, + select_block, +) +import numpy as np + +plt.style.use("../evaluation/figure.mplstyle") + + +# %% +prediction_path = Path( + "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr" +) + +embeddings = read_embedding_dataset(prediction_path) +features = embeddings["features"] + +# %% +scaled_features = StandardScaler().fit_transform(features.values) + +# %% +cross_dist = cross_dissimilarity(scaled_features, metric="cosine") +rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + +plt.imshow(cross_dist) + +# %% +""" +Computing the smoothness and dynamic range +- Get the off diagonal per block and compute the mode +- The blocks are not square, so we need to get the off diagonal elements +- Get the 1 and 99 percentile of the off diagonal per block +""" +features_df = features["sample"].to_dataframe().reset_index(drop=True) + +piece_wise_dissimilarity_per_track = [] +piece_wise_rank_difference_per_track = [] +for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + single_track_dissimilarity = select_block(cross_dist, subdata.index.values) + single_track_rank_fraction = select_block(rank_fractions, subdata.index.values) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + +# Get the median/mode of the off diagonal elements +median_piece_wise_dissimilarity = [ + np.median(track) for track in piece_wise_dissimilarity_per_track +] +p99_piece_wise_dissimilarity = [ + np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track +] +p1_percentile_piece_wise_dissimilarity = [ + np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track +] + +# Plot the histogram of the median dissimilarity +plt.figure() +plt.title("Adjacent Frame Median Dissimilarity per Track") +sns.histplot(median_piece_wise_dissimilarity, bins=30, kde=True) +plt.xlabel("Cosine Dissimilarity") +plt.ylabel("Frequency") +plt.tight_layout() +plt.show() + +# Plot the histogram of the 1 percentile dissimilarity +plt.figure() +plt.title("Adjacent Frame 1 Percentile Dissimilarity per Track") +sns.histplot(p1_percentile_piece_wise_dissimilarity, bins=30, kde=True) +plt.xlabel("Cosine Dissimilarity") +plt.ylabel("Frequency") +plt.tight_layout() +plt.show() + +# Plot the histogram of the 99 percentile dissimilarity +plt.figure() +plt.title("Adjacent Frame 99 Percentile Dissimilarity per Track") +sns.histplot(p99_piece_wise_dissimilarity, bins=30, kde=True) +plt.xlabel("Cosine Dissimilarity") +plt.ylabel("Frequency") +plt.tight_layout() +plt.show() + +# %% +# Random sampling values in the values in the dissimilarity matrix +# Random Sampling +n_samples = 2000 +sampled_values = [] +for _ in range(n_samples): + i, j = np.random.randint(0, len(cross_dist), size=2) + sampled_values.append(cross_dist[i, j]) + + +# Plot the histogram of the sampled values +plt.figure() +plt.title("Random Sampling of Dissimilarity Values") +sns.histplot(sampled_values, bins=30, kde=True, stat="density") +plt.xlabel("Cosine Dissimilarity") +plt.ylabel("Density") +plt.tight_layout() +plt.show() + +# %% +# Plot the median and the random sampling in one plot each with different colors +plt.figure() +sns.histplot( + median_piece_wise_dissimilarity, + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", +) +sns.histplot(sampled_values, bins=30, kde=True, color="red", alpha=0.5, stat="density") +plt.xlabel("Cosine Dissimilarity") +plt.ylabel("Density") +plt.tight_layout() +plt.legend(["Adjacent Frame", "Random Sample"]) +plt.show() + +# %% diff --git a/pyproject.toml b/pyproject.toml index e40db067..e4eb3a07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies = [ "matplotlib>=3.9.0", "numpy", "xarray", - "phate", "pytorch-metric-learning>2.0.0" ] dynamic = ["version"] @@ -41,6 +40,7 @@ metrics = [ "ptflops>=0.7", "umap-learn", "captum>=0.7.0", + "phate", ] examples = ["napari", "jupyter", "jupytext"] visual = [ From cf6dbe7c8c7b8d64681b828d7e215215aa5ff177 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 21 Nov 2024 18:53:52 -0800 Subject: [PATCH 24/44] cleanup to run on triplet and ntxent --- .../cosine_dissimilarity_dataset.py | 252 ++++++++++-------- 1 file changed, 142 insertions(+), 110 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index 4cb0a6bc..00356e32 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import seaborn as sns from sklearn.preprocessing import StandardScaler +from numpy.typing import NDArray from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation.clustering import ( @@ -13,124 +14,155 @@ select_block, ) import numpy as np +from tqdm import tqdm +import pandas as pd plt.style.use("../evaluation/figure.mplstyle") +# plotting +VERBOSE = False + + +def compute_piece_wise_dissimilarity( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + single_track_dissimilarity = select_block(cross_dist, subdata.index.values) + single_track_rank_fraction = select_block( + rank_fractions, subdata.index.values + ) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def plot_histogram( + data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" +): + plt.figure() + plt.title(title) + sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.tight_layout() + plt.show() -# %% -prediction_path = Path( - "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr" -) - -embeddings = read_embedding_dataset(prediction_path) -features = embeddings["features"] - -# %% -scaled_features = StandardScaler().fit_transform(features.values) - -# %% -cross_dist = cross_dissimilarity(scaled_features, metric="cosine") -rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - -plt.imshow(cross_dist) # %% -""" -Computing the smoothness and dynamic range -- Get the off diagonal per block and compute the mode -- The blocks are not square, so we need to get the off diagonal elements -- Get the 1 and 99 percentile of the off diagonal per block -""" -features_df = features["sample"].to_dataframe().reset_index(drop=True) - -piece_wise_dissimilarity_per_track = [] -piece_wise_rank_difference_per_track = [] -for name, subdata in features_df.groupby(["fov_name", "track_id"]): - if len(subdata) > 1: - single_track_dissimilarity = select_block(cross_dist, subdata.index.values) - single_track_rank_fraction = select_block(rank_fractions, subdata.index.values) - piece_wise_dissimilarity = compare_time_offset( - single_track_dissimilarity, time_offset=1 +prediction_path_1 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_1.zarr" +) +prediction_path_2 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_1.zarr" +) +for prediction_path in tqdm([prediction_path_1, prediction_path_2]): + + # Read the dataset + embeddings = read_embedding_dataset(prediction_path) + features = embeddings["features"] + + scaled_features = StandardScaler().fit_transform(features.values) + # COmpute the cosine dissimilarity + cross_dist = cross_dissimilarity(scaled_features, metric="cosine") + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + plt.figure() + plt.imshow(cross_dist) + plt.show() + + # Compute piece-wise dissimilarity and rank difference + features_df = features["sample"].to_dataframe().reset_index(drop=True) + piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( + compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) + ) + + # Get the median/mode of the off diagonal elements + median_piece_wise_dissimilarity = [ + np.median(track) for track in piece_wise_dissimilarity_per_track + ] + p99_piece_wise_dissimilarity = [ + np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track + ] + p1_percentile_piece_wise_dissimilarity = [ + np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track + ] + + # Random sampling values in the dissimilarity matrix + n_samples = 2000 + sampled_values = [ + cross_dist[ + np.random.randint(0, len(cross_dist)), np.random.randint(0, len(cross_dist)) + ] + for _ in range(n_samples) + ] + + if VERBOSE: + # Plot histograms + plot_histogram( + median_piece_wise_dissimilarity, + "Adjacent Frame Median Dissimilarity per Track", + "Cosine Dissimilarity", + "Frequency", ) - piece_wise_rank_difference = compare_time_offset( - single_track_rank_fraction, time_offset=1 + plot_histogram( + p1_percentile_piece_wise_dissimilarity, + "Adjacent Frame 1 Percentile Dissimilarity per Track", + "Cosine Dissimilarity", + "Frequency", + ) + plot_histogram( + p99_piece_wise_dissimilarity, + "Adjacent Frame 99 Percentile Dissimilarity per Track", + "Cosine Dissimilarity", + "Frequency", + ) + # Plot the histogram of the sampled values + plot_histogram( + sampled_values, + "Random Sampling of Dissimilarity Values", + "Cosine Dissimilarity", + "Density", + color="red", + alpha=0.5, + stat="density", ) - piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) - piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) - -# Get the median/mode of the off diagonal elements -median_piece_wise_dissimilarity = [ - np.median(track) for track in piece_wise_dissimilarity_per_track -] -p99_piece_wise_dissimilarity = [ - np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track -] -p1_percentile_piece_wise_dissimilarity = [ - np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track -] - -# Plot the histogram of the median dissimilarity -plt.figure() -plt.title("Adjacent Frame Median Dissimilarity per Track") -sns.histplot(median_piece_wise_dissimilarity, bins=30, kde=True) -plt.xlabel("Cosine Dissimilarity") -plt.ylabel("Frequency") -plt.tight_layout() -plt.show() - -# Plot the histogram of the 1 percentile dissimilarity -plt.figure() -plt.title("Adjacent Frame 1 Percentile Dissimilarity per Track") -sns.histplot(p1_percentile_piece_wise_dissimilarity, bins=30, kde=True) -plt.xlabel("Cosine Dissimilarity") -plt.ylabel("Frequency") -plt.tight_layout() -plt.show() - -# Plot the histogram of the 99 percentile dissimilarity -plt.figure() -plt.title("Adjacent Frame 99 Percentile Dissimilarity per Track") -sns.histplot(p99_piece_wise_dissimilarity, bins=30, kde=True) -plt.xlabel("Cosine Dissimilarity") -plt.ylabel("Frequency") -plt.tight_layout() -plt.show() - -# %% -# Random sampling values in the values in the dissimilarity matrix -# Random Sampling -n_samples = 2000 -sampled_values = [] -for _ in range(n_samples): - i, j = np.random.randint(0, len(cross_dist), size=2) - sampled_values.append(cross_dist[i, j]) - - -# Plot the histogram of the sampled values -plt.figure() -plt.title("Random Sampling of Dissimilarity Values") -sns.histplot(sampled_values, bins=30, kde=True, stat="density") -plt.xlabel("Cosine Dissimilarity") -plt.ylabel("Density") -plt.tight_layout() -plt.show() -# %% -# Plot the median and the random sampling in one plot each with different colors -plt.figure() -sns.histplot( - median_piece_wise_dissimilarity, - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", -) -sns.histplot(sampled_values, bins=30, kde=True, color="red", alpha=0.5, stat="density") -plt.xlabel("Cosine Dissimilarity") -plt.ylabel("Density") -plt.tight_layout() -plt.legend(["Adjacent Frame", "Random Sample"]) -plt.show() + # Plot the median and the random sampling in one plot each with different colors + fig = plt.figure() + sns.histplot( + median_piece_wise_dissimilarity, + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + sampled_values, bins=30, kde=True, color="red", alpha=0.5, stat="density" + ) + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample"]) + plt.show() + fig.savefig( + f"./cosine_dissimilarity_smoothness_{prediction_path.stem}.pdf", + dpi=300, + ) # %% From a98e8825b2705e0518c5e02ae7075bfbd8efdcc4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 22 Nov 2024 14:36:21 -0800 Subject: [PATCH 25/44] fix plots for smoothnes --- .../cosine_dissimilarity_dataset.py | 104 +++++++++++------- 1 file changed, 62 insertions(+), 42 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index 00356e32..021eb432 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -17,6 +17,7 @@ from tqdm import tqdm import pandas as pd + plt.style.use("../evaluation/figure.mplstyle") # plotting @@ -36,10 +37,9 @@ def compute_piece_wise_dissimilarity( piece_wise_rank_difference_per_track = [] for name, subdata in features_df.groupby(["fov_name", "track_id"]): if len(subdata) > 1: - single_track_dissimilarity = select_block(cross_dist, subdata.index.values) - single_track_rank_fraction = select_block( - rank_fractions, subdata.index.values - ) + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) piece_wise_dissimilarity = compare_time_offset( single_track_dissimilarity, time_offset=1 ) @@ -64,20 +64,25 @@ def plot_histogram( # %% +PATH_TO_GDRIVE_FIGUE = "/home/eduardo.hirata/mydata/gdrive/publications/learning_impacts_of_infection/fig_manuscript/rev2_ICLR_fig/" + prediction_path_1 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_1.zarr" + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" ) prediction_path_2 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_1.zarr" + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" ) -for prediction_path in tqdm([prediction_path_1, prediction_path_2]): + +for prediction_path, loss_name in tqdm( + [(prediction_path_1, "ntxent"), (prediction_path_2, "triplet")] +): # Read the dataset embeddings = read_embedding_dataset(prediction_path) features = embeddings["features"] scaled_features = StandardScaler().fit_transform(features.values) - # COmpute the cosine dissimilarity + # Compute the cosine dissimilarity cross_dist = cross_dissimilarity(scaled_features, metric="cosine") rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) @@ -91,43 +96,58 @@ def plot_histogram( compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) ) - # Get the median/mode of the off diagonal elements - median_piece_wise_dissimilarity = [ - np.median(track) for track in piece_wise_dissimilarity_per_track - ] - p99_piece_wise_dissimilarity = [ - np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track - ] - p1_percentile_piece_wise_dissimilarity = [ - np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track - ] + all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) + + # # Get the median/mode of the off diagonal elements + # median_piece_wise_dissimilarity = np.array( + # [np.median(track) for track in piece_wise_dissimilarity_per_track] + # ) + p99_piece_wise_dissimilarity = np.array( + [np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track] + ) + p1_percentile_piece_wise_dissimilarity = np.array( + [np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track] + ) # Random sampling values in the dissimilarity matrix - n_samples = 2000 - sampled_values = [ - cross_dist[ - np.random.randint(0, len(cross_dist)), np.random.randint(0, len(cross_dist)) - ] - for _ in range(n_samples) - ] + n_samples = 3000 + random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) + sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] + + print(f"Dissimilarity Statistics for {prediction_path.stem}") + print(f"Mean: {np.mean(all_dissimilarity)}") + print(f"Std: {np.std(all_dissimilarity)}") + print(f"Median: {np.median(all_dissimilarity)}") + + print(f"Distance Statistics for random sampling") + print(f"Mean: {np.mean(sampled_values)}") + print(f"Std: {np.std(sampled_values)}") + print(f"Median: {np.median(sampled_values)}") if VERBOSE: # Plot histograms + # plot_histogram( + # median_piece_wise_dissimilarity, + # "Adjacent Frame Median Dissimilarity per Track", + # "Cosine Dissimilarity", + # "Frequency", + # ) + # plot_histogram( + # p1_percentile_piece_wise_dissimilarity, + # "Adjacent Frame 1 Percentile Dissimilarity per Track", + # "Cosine Dissimilarity", + # "Frequency", + # ) + # plot_histogram( + # p99_piece_wise_dissimilarity, + # "Adjacent Frame 99 Percentile Dissimilarity per Track", + # "Cosine Dissimilarity", + # "Frequency", + # ) + plot_histogram( - median_piece_wise_dissimilarity, - "Adjacent Frame Median Dissimilarity per Track", - "Cosine Dissimilarity", - "Frequency", - ) - plot_histogram( - p1_percentile_piece_wise_dissimilarity, - "Adjacent Frame 1 Percentile Dissimilarity per Track", - "Cosine Dissimilarity", - "Frequency", - ) - plot_histogram( - p99_piece_wise_dissimilarity, - "Adjacent Frame 99 Percentile Dissimilarity per Track", + piece_wise_dissimilarity_per_track, + "Adjacent Frame Dissimilarity per Track", "Cosine Dissimilarity", "Frequency", ) @@ -145,7 +165,7 @@ def plot_histogram( # Plot the median and the random sampling in one plot each with different colors fig = plt.figure() sns.histplot( - median_piece_wise_dissimilarity, + all_dissimilarity, bins=30, kde=True, color="cyan", @@ -161,8 +181,8 @@ def plot_histogram( plt.legend(["Adjacent Frame", "Random Sample"]) plt.show() fig.savefig( - f"./cosine_dissimilarity_smoothness_{prediction_path.stem}.pdf", - dpi=300, + f"{PATH_TO_GDRIVE_FIGUE}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf", + dpi=600, ) # %% From b75ddd1b0974b67d7d5f9c9d206be690b8659c89 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 11:40:11 -0800 Subject: [PATCH 26/44] nexnt loss prototype --- viscy/representation/engine.py | 139 ++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 63 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 927b6169..f1153f36 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, TypedDict +from typing import Literal, Sequence, TypedDict, Tuple import numpy as np import torch @@ -16,6 +16,61 @@ _logger = logging.getLogger("lightning.pytorch") +class NTXentLoss(torch.nn.Module): + """ + Normalized Temperature-scaled Cross Entropy Loss + + From Chen et.al, https://arxiv.org/abs/2002.05709 + """ + + def __init__(self, batch_size, temperature=0.5): + super(NTXentLoss, self).__init__() + self.batch_size = batch_size + self.temperature = temperature + self.mask = self._get_correlated_mask(batch_size) + self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + + def _get_correlated_mask(self, batch_size): + mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) + mask = mask.fill_diagonal_(0) + for i in range(batch_size): + mask[i, batch_size + i] = 0 + mask[batch_size + i, i] = 0 + return mask + + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(self, zis, zjs): + """ + zis and zjs are the output projections from the two augmented views + + Here, we assume the two augmented views are the anchor and positive samples + """ + # Concatenate representations along the batch dimension + representations = torch.cat([zis, zjs], dim=0) + + # Cosine similarity + similarity_matrix = F.cosine_similarity( + representations.unsqueeze(1), representations.unsqueeze(0), dim=2 + ) + + # Temperature scaling + similarity_matrix = similarity_matrix / self.temperature + + # Find the valid pairs of positive samples + positive_samples = torch.cat( + [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 + ) + + # Mask out unwanted pairs + similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) + + # Calculate NT-Xent Loss as cross-entropy + loss = self.criterion(similarity_matrix, positive_samples) + loss /= 2 * self.batch_size + + return loss + + class ContrastivePrediction(TypedDict): features: Tensor projections: Tensor @@ -131,43 +186,25 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_projection = self(anchor_img) positive_projection = self(pos_img) if isinstance(self.loss_function, NTXentLoss): - indices = torch.arange( - 0, anchor_projection.size(0), device=anchor_projection.device - ) - labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - embeddings = torch.cat((anchor_projection, positive_projection)) - loss = self.loss_function(embeddings, labels) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=None, - stage="train", - ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample((anchor_img, pos_img), self.log_samples_per_batch) - ) + loss = self.loss_function(anchor_projection, positive_projection) else: - neg_img = batch["negative"] - negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=negative_projection, - stage="train", - ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (anchor_img, pos_img, neg_img), self.log_samples_per_batch - ) + self._log_metrics( + loss, + anchor_projection, + positive_projection, + negative_projection, + stage="train", + ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample( + (anchor_img, pos_img, neg_img), self.log_samples_per_batch ) + ) return loss def on_train_epoch_end(self) -> None: @@ -188,43 +225,19 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_projection = self(anchor) positive_projection = self(pos_img) if isinstance(self.loss_function, NTXentLoss): - indices = torch.arange( - 0, anchor_projection.size(0), device=anchor_projection.device - ) - labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - embeddings = torch.cat((anchor_projection, positive_projection)) - loss = self.loss_function(embeddings, labels) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=None, - stage="val", - ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample((anchor, pos_img), self.log_samples_per_batch) - ) + loss = self.loss_function(anchor_projection, positive_projection) else: - neg_img = batch["negative"] - negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=negative_projection, - stage="val", + self._log_metrics( + loss, anchor_projection, positive_projection, negative_projection, "val" + ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + detach_sample((anchor, pos_img, neg_img), self.log_samples_per_batch) ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample( - (anchor, pos_img, neg_img), self.log_samples_per_batch - ) - ) return loss def on_validation_epoch_end(self) -> None: From 98b36de65e2b5d62ce0b99aac12a6296b9de3ca9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 14:10:36 -0800 Subject: [PATCH 27/44] exclude the negative pair from dataloader and forward pass --- viscy/representation/engine.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index f1153f36..604e9009 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -59,7 +59,7 @@ def forward(self, zis, zjs): # Find the valid pairs of positive samples positive_samples = torch.cat( [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 - ) + ).to(similarity_matrix.device) # Mask out unwanted pairs similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) @@ -188,17 +188,26 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: if isinstance(self.loss_function, NTXentLoss): # Note: we assume the two augmented views are the anchor and positive samples loss = self.loss_function(anchor_projection, positive_projection) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=None, + stage="train", + ) else: + neg_img = batch["negative"] + negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss, - anchor_projection, - positive_projection, - negative_projection, - stage="train", - ) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="train", + ) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( detach_sample( From 5bc2ff66746737947237d8f440d54c66380a04ba Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 13 Nov 2024 19:09:43 -0800 Subject: [PATCH 28/44] adding option using pytorch-metric-learning implementation and modifying previous to match same input args --- viscy/representation/engine.py | 112 ++++++++++++++++++++++----------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 604e9009..a73b1a90 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -12,23 +12,27 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images +from pytorch_metric_learning.losses import SelfSupervisedLoss +from pytorch_metric_learning.losses import NTXentLoss as NTXentLoss_pml _logger = logging.getLogger("lightning.pytorch") -class NTXentLoss(torch.nn.Module): +class NTXentLoss_viscy(torch.nn.Module): """ Normalized Temperature-scaled Cross Entropy Loss From Chen et.al, https://arxiv.org/abs/2002.05709 """ - def __init__(self, batch_size, temperature=0.5): - super(NTXentLoss, self).__init__() - self.batch_size = batch_size + def __init__( + self, + temperature=0.5, + criterion=torch.nn.CrossEntropyLoss(reduction="sum"), + ): + super(NTXentLoss_viscy, self).__init__() self.temperature = temperature - self.mask = self._get_correlated_mask(batch_size) - self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + self.criterion = criterion def _get_correlated_mask(self, batch_size): mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) @@ -36,37 +40,36 @@ def _get_correlated_mask(self, batch_size): for i in range(batch_size): mask[i, batch_size + i] = 0 mask[batch_size + i, i] = 0 + _logger.info(f"mask: {mask}") return mask @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, zis, zjs): + def forward(self, embeddings, labels): """ - zis and zjs are the output projections from the two augmented views - + embeddings = [zis, zjs] + zis and zjs are the output projections from the two augmented views. Here, we assume the two augmented views are the anchor and positive samples """ - # Concatenate representations along the batch dimension - representations = torch.cat([zis, zjs], dim=0) + # Get the batch size from tensor + batch_size = embeddings.shape[0] // 2 + + zis, zjs = torch.split(embeddings, batch_size, dim=0) # Cosine similarity similarity_matrix = F.cosine_similarity( - representations.unsqueeze(1), representations.unsqueeze(0), dim=2 + embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2 ) - # Temperature scaling similarity_matrix = similarity_matrix / self.temperature - # Find the valid pairs of positive samples - positive_samples = torch.cat( - [torch.arange(self.batch_size), torch.arange(self.batch_size)], dim=0 - ).to(similarity_matrix.device) + mask = self._get_correlated_mask(batch_size).to(similarity_matrix.device) # Mask out unwanted pairs - similarity_matrix = similarity_matrix[self.mask].view(2 * self.batch_size, -1) + similarity_matrix = similarity_matrix[mask].view(2 * batch_size, -1) # Calculate NT-Xent Loss as cross-entropy - loss = self.criterion(similarity_matrix, positive_samples) - loss /= 2 * self.batch_size + loss = self.criterion(similarity_matrix, labels) + loss /= 2 * batch_size return loss @@ -84,7 +87,11 @@ def __init__( self, encoder: nn.Module | ContrastiveEncoder, loss_function: ( - nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + nn.Module + | nn.CosineEmbeddingLoss + | nn.TripletMarginLoss + | NTXentLoss_pml + | NTXentLoss_viscy ) = nn.TripletMarginLoss(margin=0.5), lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", @@ -185,9 +192,14 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor_img) positive_projection = self(pos_img) - if isinstance(self.loss_function, NTXentLoss): + if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + indices = torch.arange( + 0, anchor_projection.size(0), device=anchor_projection.device + ) + labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - loss = self.loss_function(anchor_projection, positive_projection) + embeddings = torch.cat((anchor_projection, positive_projection)) + loss = self.loss_function(embeddings, labels) self._log_metrics( loss=loss, anchor=anchor_projection, @@ -195,6 +207,10 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: negative=None, stage="train", ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample((anchor_img, pos_img), self.log_samples_per_batch) + ) else: neg_img = batch["negative"] negative_projection = self(neg_img) @@ -208,12 +224,12 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: negative=negative_projection, stage="train", ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (anchor_img, pos_img, neg_img), self.log_samples_per_batch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + detach_sample( + (anchor_img, pos_img, neg_img), self.log_samples_per_batch + ) ) - ) return loss def on_train_epoch_end(self) -> None: @@ -233,20 +249,44 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor) positive_projection = self(pos_img) - if isinstance(self.loss_function, NTXentLoss): + if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + indices = torch.arange( + 0, anchor_projection.size(0), device=anchor_projection.device + ) + labels = torch.cat((indices, indices)) # Note: we assume the two augmented views are the anchor and positive samples - loss = self.loss_function(anchor_projection, positive_projection) + embeddings = torch.cat((anchor_projection, positive_projection)) + loss = self.loss_function(embeddings, labels) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=None, + stage="val", + ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + detach_sample((anchor, pos_img), self.log_samples_per_batch) + ) else: + neg_img = batch["negative"] + negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss, anchor_projection, positive_projection, negative_projection, "val" - ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample((anchor, pos_img, neg_img), self.log_samples_per_batch) + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="val", ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + detach_sample( + (anchor, pos_img, neg_img), self.log_samples_per_batch + ) + ) return loss def on_validation_epoch_end(self) -> None: From 472be4e23017362c23840cc5448cec444117ffa2 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 14 Nov 2024 14:34:05 -0800 Subject: [PATCH 29/44] removing our implementation of NTXentLoss and using pytorch metric --- viscy/representation/engine.py | 70 ++-------------------------------- 1 file changed, 4 insertions(+), 66 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index a73b1a90..4ee53209 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, TypedDict, Tuple +from typing import Literal, Sequence, Tuple, TypedDict import numpy as np import torch @@ -12,68 +12,10 @@ from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images -from pytorch_metric_learning.losses import SelfSupervisedLoss -from pytorch_metric_learning.losses import NTXentLoss as NTXentLoss_pml _logger = logging.getLogger("lightning.pytorch") -class NTXentLoss_viscy(torch.nn.Module): - """ - Normalized Temperature-scaled Cross Entropy Loss - - From Chen et.al, https://arxiv.org/abs/2002.05709 - """ - - def __init__( - self, - temperature=0.5, - criterion=torch.nn.CrossEntropyLoss(reduction="sum"), - ): - super(NTXentLoss_viscy, self).__init__() - self.temperature = temperature - self.criterion = criterion - - def _get_correlated_mask(self, batch_size): - mask = torch.ones((2 * batch_size, 2 * batch_size), dtype=bool) - mask = mask.fill_diagonal_(0) - for i in range(batch_size): - mask[i, batch_size + i] = 0 - mask[batch_size + i, i] = 0 - _logger.info(f"mask: {mask}") - return mask - - @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, embeddings, labels): - """ - embeddings = [zis, zjs] - zis and zjs are the output projections from the two augmented views. - Here, we assume the two augmented views are the anchor and positive samples - """ - # Get the batch size from tensor - batch_size = embeddings.shape[0] // 2 - - zis, zjs = torch.split(embeddings, batch_size, dim=0) - - # Cosine similarity - similarity_matrix = F.cosine_similarity( - embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2 - ) - # Temperature scaling - similarity_matrix = similarity_matrix / self.temperature - - mask = self._get_correlated_mask(batch_size).to(similarity_matrix.device) - - # Mask out unwanted pairs - similarity_matrix = similarity_matrix[mask].view(2 * batch_size, -1) - - # Calculate NT-Xent Loss as cross-entropy - loss = self.criterion(similarity_matrix, labels) - loss /= 2 * batch_size - - return loss - - class ContrastivePrediction(TypedDict): features: Tensor projections: Tensor @@ -87,11 +29,7 @@ def __init__( self, encoder: nn.Module | ContrastiveEncoder, loss_function: ( - nn.Module - | nn.CosineEmbeddingLoss - | nn.TripletMarginLoss - | NTXentLoss_pml - | NTXentLoss_viscy + nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss ) = nn.TripletMarginLoss(margin=0.5), lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", @@ -192,7 +130,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor_img) positive_projection = self(pos_img) - if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device ) @@ -249,7 +187,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor) positive_projection = self(pos_img) - if isinstance(self.loss_function, (NTXentLoss_pml, NTXentLoss_viscy)): + if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device ) From 66b389990a1351d7cdb88649dcf8f4cdedd4d580 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 14 Nov 2024 14:46:05 -0800 Subject: [PATCH 30/44] ruff --- viscy/representation/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 4ee53209..927b6169 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Sequence, Tuple, TypedDict +from typing import Literal, Sequence, TypedDict import numpy as np import torch From 9cc35cf309316fe479b8c1cde9f0f312528f2136 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:21:36 -0800 Subject: [PATCH 31/44] remove blank line diff --- viscy/representation/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 927b6169..eb93bf25 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -80,7 +80,6 @@ def _log_metrics( logger=True, sync_dist=True, ) - cosine_sim_pos = F.cosine_similarity(anchor, positive, dim=1).mean() euclidean_dist_pos = F.pairwise_distance(anchor, positive).mean() log_metric_dict = { From 3c96f136c241f95c168349bc6664451da543353d Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:21:44 -0800 Subject: [PATCH 32/44] remove blank line diff --- viscy/data/triplet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index b348a41e..95717dc7 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -262,7 +262,6 @@ def __getitem__(self, index: int) -> TripletSample: norm_meta=anchor_norm, ) sample = {"anchor": anchor_patch} - if self.fit: if self.return_negative: sample.update({"positive": positive_patch, "negative": negative_patch}) From d3d471579684e360db089ce68cee4129309630c7 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 9 Dec 2024 22:18:37 -0800 Subject: [PATCH 33/44] simplying the engine --- viscy/representation/engine.py | 78 +++++++++++++--------------------- 1 file changed, 30 insertions(+), 48 deletions(-) diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index eb93bf25..ca67263a 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -111,6 +111,16 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): key, grid, self.current_epoch, dataformats="HWC" ) + def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): + """Common method for logging step samples""" + if batch_idx < self.log_batches_per_epoch: + output_list = ( + self.training_step_outputs + if stage == "train" + else self.validation_step_outputs + ) + output_list.extend(detach_sample(samples, self.log_samples_per_batch)) + def log_embedding_umap(self, embeddings: Tensor, tag: str): _logger.debug(f"Computing UMAP for {tag} embeddings.") umap = UMAP(n_components=2) @@ -129,6 +139,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor_img) positive_projection = self(pos_img) + negative_projection = None if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device @@ -137,36 +148,21 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: # Note: we assume the two augmented views are the anchor and positive samples embeddings = torch.cat((anchor_projection, positive_projection)) loss = self.loss_function(embeddings, labels) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=None, - stage="train", - ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample((anchor_img, pos_img), self.log_samples_per_batch) - ) + self._log_step_samples(batch_idx, (anchor_img, pos_img), "train") else: neg_img = batch["negative"] negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=negative_projection, - stage="train", - ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - detach_sample( - (anchor_img, pos_img, neg_img), self.log_samples_per_batch - ) - ) + self._log_step_samples(batch_idx, (anchor_img, pos_img, neg_img), "train") + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="train", + ) return loss def on_train_epoch_end(self) -> None: @@ -186,6 +182,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: pos_img = batch["positive"] anchor_projection = self(anchor) positive_projection = self(pos_img) + negative_projection = None if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( 0, anchor_projection.size(0), device=anchor_projection.device @@ -194,36 +191,21 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: # Note: we assume the two augmented views are the anchor and positive samples embeddings = torch.cat((anchor_projection, positive_projection)) loss = self.loss_function(embeddings, labels) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=None, - stage="val", - ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample((anchor, pos_img), self.log_samples_per_batch) - ) + self._log_step_samples(batch_idx, (anchor, pos_img), "val") else: neg_img = batch["negative"] negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) - self._log_metrics( - loss=loss, - anchor=anchor_projection, - positive=positive_projection, - negative=negative_projection, - stage="val", - ) - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - detach_sample( - (anchor, pos_img, neg_img), self.log_samples_per_batch - ) - ) + self._log_step_samples(batch_idx, (anchor, pos_img, neg_img), "val") + self._log_metrics( + loss=loss, + anchor=anchor_projection, + positive=positive_projection, + negative=negative_projection, + stage="val", + ) return loss def on_validation_epoch_end(self) -> None: From 70dfb5d50efcccef7605c9e77d86958ff91411bf Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 21 Nov 2024 14:53:54 -0800 Subject: [PATCH 34/44] explicit target shape argument in the HCS data module --- viscy/data/hcs.py | 74 +++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index c4087941..29bbd8cc 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -274,33 +274,51 @@ def __getitem__(self, index: int) -> Sample: class HCSDataModule(LightningDataModule): - """Lightning data module for a preprocessed HCS NGFF Store. - - :param str data_path: path to the data store - :param str | Sequence[str] source_channel: name(s) of the source channel, - e.g. ``'Phase'`` - :param str | Sequence[str] target_channel: name(s) of the target channel, - e.g. ``['Nuclei', 'Membrane']`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param float split_ratio: split ratio of the training subset in the fit stage, - e.g. 0.8 means a 80/20 split between training/validation, - by default 0.8 - :param int batch_size: batch size, defaults to 16 - :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, - defaults to "2.5D" - :param tuple[int, int] yx_patch_size: patch size in (Y, X), - defaults to (256, 256) - :param list[MapTransform] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to [] (no normalization) - :param list[MapTransform] augmentations: MONAI dictionary transforms - applied to the training set, defaults to [] (no augmentation) - :param bool caching: whether to decompress all the images and cache the result, - will store in ``/tmp/$SLURM_JOB_ID/`` if available, - defaults to False - :param Path | None ground_truth_masks: path to the ground truth masks, + """ + Lightning data module for a preprocessed HCS NGFF Store. + + Parameters + ---------- + data_path : str + Path to the data store. + source_channel : str or Sequence[str] + Name(s) of the source channel, e.g. 'Phase'. + target_channel : str or Sequence[str] + Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D. + split_ratio : float, optional + Split ratio of the training subset in the fit stage, + e.g. 0.8 means an 80/20 split between training/validation, + by default 0.8. + batch_size : int, optional + Batch size, defaults to 16. + num_workers : int, optional + Number of data-loading workers, defaults to 8. + architecture : Literal["2D", "UNeXt2", "2.5D", "3D"], optional + U-Net architecture, defaults to "2.5D". + yx_patch_size : tuple[int, int], optional + Patch size in (Y, X), defaults to (256, 256). + normalizations : list of MapTransform, optional + MONAI dictionary transforms applied to selected channels, + defaults to ``[]`` (no normalization). + augmentations : list of MapTransform, optional + MONAI dictionary transforms applied to the training set, + defaults to ``[]`` (no augmentation). + caching : bool, optional + Whether to decompress all the images and cache the result, + will store in `/tmp/$SLURM_JOB_ID/` if available, + defaults to False. + ground_truth_masks : Path or None, optional + Path to the ground truth masks, used in the test stage to compute segmentation metrics, - defaults to None + defaults to None. + persistent_workers : bool, optional + Whether to keep the workers alive between fitting epochs, + defaults to False. + prefetch_factor : int or None, optional + Number of samples loaded in advance by each worker during fitting, + defaults to None (2 per PyTorch default). """ def __init__( @@ -312,7 +330,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", + target_2d: bool = False, yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -327,7 +345,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.target_2d = target_2d self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size From a845847bcf7bed63e9b620dd12be3eb46506c993 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 21 Nov 2024 14:54:34 -0800 Subject: [PATCH 35/44] Revert "explicit target shape argument in the HCS data module" This reverts commit 464d4c9ed152ddae1d61652158328f203c78af40. --- viscy/data/hcs.py | 74 ++++++++++++++++++----------------------------- 1 file changed, 28 insertions(+), 46 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 29bbd8cc..c4087941 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -274,51 +274,33 @@ def __getitem__(self, index: int) -> Sample: class HCSDataModule(LightningDataModule): - """ - Lightning data module for a preprocessed HCS NGFF Store. - - Parameters - ---------- - data_path : str - Path to the data store. - source_channel : str or Sequence[str] - Name(s) of the source channel, e.g. 'Phase'. - target_channel : str or Sequence[str] - Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. - z_window_size : int - Z window size of the 2.5D U-Net, 1 for 2D. - split_ratio : float, optional - Split ratio of the training subset in the fit stage, - e.g. 0.8 means an 80/20 split between training/validation, - by default 0.8. - batch_size : int, optional - Batch size, defaults to 16. - num_workers : int, optional - Number of data-loading workers, defaults to 8. - architecture : Literal["2D", "UNeXt2", "2.5D", "3D"], optional - U-Net architecture, defaults to "2.5D". - yx_patch_size : tuple[int, int], optional - Patch size in (Y, X), defaults to (256, 256). - normalizations : list of MapTransform, optional - MONAI dictionary transforms applied to selected channels, - defaults to ``[]`` (no normalization). - augmentations : list of MapTransform, optional - MONAI dictionary transforms applied to the training set, - defaults to ``[]`` (no augmentation). - caching : bool, optional - Whether to decompress all the images and cache the result, - will store in `/tmp/$SLURM_JOB_ID/` if available, - defaults to False. - ground_truth_masks : Path or None, optional - Path to the ground truth masks, + """Lightning data module for a preprocessed HCS NGFF Store. + + :param str data_path: path to the data store + :param str | Sequence[str] source_channel: name(s) of the source channel, + e.g. ``'Phase'`` + :param str | Sequence[str] target_channel: name(s) of the target channel, + e.g. ``['Nuclei', 'Membrane']`` + :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param float split_ratio: split ratio of the training subset in the fit stage, + e.g. 0.8 means a 80/20 split between training/validation, + by default 0.8 + :param int batch_size: batch size, defaults to 16 + :param int num_workers: number of data-loading workers, defaults to 8 + :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, + defaults to "2.5D" + :param tuple[int, int] yx_patch_size: patch size in (Y, X), + defaults to (256, 256) + :param list[MapTransform] normalizations: MONAI dictionary transforms + applied to selected channels, defaults to [] (no normalization) + :param list[MapTransform] augmentations: MONAI dictionary transforms + applied to the training set, defaults to [] (no augmentation) + :param bool caching: whether to decompress all the images and cache the result, + will store in ``/tmp/$SLURM_JOB_ID/`` if available, + defaults to False + :param Path | None ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, - defaults to None. - persistent_workers : bool, optional - Whether to keep the workers alive between fitting epochs, - defaults to False. - prefetch_factor : int or None, optional - Number of samples loaded in advance by each worker during fitting, - defaults to None (2 per PyTorch default). + defaults to None """ def __init__( @@ -330,7 +312,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - target_2d: bool = False, + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -345,7 +327,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = target_2d + self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size From a6e2318c04aa449c027e9076640a587a6c5b05ec Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:01:40 -0800 Subject: [PATCH 36/44] Explicit target shape argument in the HCS data module (#212) * explicit target shape argument in the HCS data module * update docstring * update test cases --- tests/data/test_data.py | 7 ++-- viscy/data/hcs.py | 75 ++++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index a75f8da8..c71488c4 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -58,7 +58,7 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio batch_size=batch_size, num_workers=0, augmentations=transforms, - architecture="3D", + target_2d=False, split_ratio=split_ratio, yx_patch_size=yx_patch_size, ) @@ -78,9 +78,9 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio ) -def test_datamodule_setup_predict(preprocessed_hcs_dataset): +@mark.parametrize("z_window_size", [1, 5]) +def test_datamodule_setup_predict(preprocessed_hcs_dataset, z_window_size): data_path = preprocessed_hcs_dataset - z_window_size = 5 channel_split = 2 with open_ome_zarr(data_path) as dataset: channel_names = dataset.channel_names @@ -91,6 +91,7 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): source_channel=channel_names[:channel_split], target_channel=channel_names[channel_split:], z_window_size=z_window_size, + target_2d=bool(z_window_size == 1), batch_size=2, num_workers=0, ) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index c4087941..88111cc3 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -274,33 +274,52 @@ def __getitem__(self, index: int) -> Sample: class HCSDataModule(LightningDataModule): - """Lightning data module for a preprocessed HCS NGFF Store. - - :param str data_path: path to the data store - :param str | Sequence[str] source_channel: name(s) of the source channel, - e.g. ``'Phase'`` - :param str | Sequence[str] target_channel: name(s) of the target channel, - e.g. ``['Nuclei', 'Membrane']`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param float split_ratio: split ratio of the training subset in the fit stage, - e.g. 0.8 means a 80/20 split between training/validation, - by default 0.8 - :param int batch_size: batch size, defaults to 16 - :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, - defaults to "2.5D" - :param tuple[int, int] yx_patch_size: patch size in (Y, X), - defaults to (256, 256) - :param list[MapTransform] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to [] (no normalization) - :param list[MapTransform] augmentations: MONAI dictionary transforms - applied to the training set, defaults to [] (no augmentation) - :param bool caching: whether to decompress all the images and cache the result, - will store in ``/tmp/$SLURM_JOB_ID/`` if available, - defaults to False - :param Path | None ground_truth_masks: path to the ground truth masks, + """ + Lightning data module for a preprocessed HCS NGFF Store. + + Parameters + ---------- + data_path : str + Path to the data store. + source_channel : str or Sequence[str] + Name(s) of the source channel, e.g. 'Phase'. + target_channel : str or Sequence[str] + Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D. + split_ratio : float, optional + Split ratio of the training subset in the fit stage, + e.g. 0.8 means an 80/20 split between training/validation, + by default 0.8. + batch_size : int, optional + Batch size, defaults to 16. + num_workers : int, optional + Number of data-loading workers, defaults to 8. + target_2d : bool, optional + Whether the target is 2D (e.g. in a 2.5D model), + defaults to False. + yx_patch_size : tuple[int, int], optional + Patch size in (Y, X), defaults to (256, 256). + normalizations : list of MapTransform, optional + MONAI dictionary transforms applied to selected channels, + defaults to ``[]`` (no normalization). + augmentations : list of MapTransform, optional + MONAI dictionary transforms applied to the training set, + defaults to ``[]`` (no augmentation). + caching : bool, optional + Whether to decompress all the images and cache the result, + will store in `/tmp/$SLURM_JOB_ID/` if available, + defaults to False. + ground_truth_masks : Path or None, optional + Path to the ground truth masks, used in the test stage to compute segmentation metrics, - defaults to None + defaults to None. + persistent_workers : bool, optional + Whether to keep the workers alive between fitting epochs, + defaults to False. + prefetch_factor : int or None, optional + Number of samples loaded in advance by each worker during fitting, + defaults to None (2 per PyTorch default). """ def __init__( @@ -312,7 +331,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", + target_2d: bool = False, yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -327,7 +346,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.target_2d = target_2d self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size From fa74b0d0f75b9013987ed80b3714210385ddb86f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 2 Dec 2024 15:29:58 -0800 Subject: [PATCH 37/44] Gradio example (#158) * initial demo * using the predict_step * modifying paths to chkpt and example pngs * updating gradio as the one on Huggingface --- examples/gradio/demo_gradio.py | 144 +++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 examples/gradio/demo_gradio.py diff --git a/examples/gradio/demo_gradio.py b/examples/gradio/demo_gradio.py new file mode 100644 index 00000000..bcee613c --- /dev/null +++ b/examples/gradio/demo_gradio.py @@ -0,0 +1,144 @@ +import gradio as gr +import torch +from viscy.light.engine import VSUNet +from huggingface_hub import hf_hub_download +from numpy.typing import ArrayLike +import numpy as np +from skimage import exposure + + +class VSGradio: + def __init__(self, model_config, model_ckpt_path): + self.model_config = model_config + self.model_ckpt_path = model_ckpt_path + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = None + self.load_model() + + def load_model(self): + # Load the model checkpoint and move it to the correct device (GPU or CPU) + self.model = VSUNet.load_from_checkpoint( + self.model_ckpt_path, + architecture="UNeXt2_2D", + model_config=self.model_config, + ) + self.model.to(self.device) # Move the model to the correct device (GPU/CPU) + self.model.eval() + + def normalize_fov(self, input: ArrayLike): + "Normalizing the fov with zero mean and unit variance" + mean = np.mean(input) + std = np.std(input) + return (input - mean) / std + + def predict(self, inp): + # Normalize the input and convert to tensor + inp = self.normalize_fov(inp) + inp = torch.from_numpy(np.array(inp).astype(np.float32)) + + # Prepare the input dictionary and move input to the correct device (GPU or CPU) + test_dict = dict( + index=None, + source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device), + ) + + # Run model inference + with torch.inference_mode(): + self.model.on_predict_start() # Necessary preprocessing for the model + pred = ( + self.model.predict_step(test_dict, 0, 0).cpu().numpy() + ) # Move output back to CPU for post-processing + + # Post-process the model output and rescale intensity + nuc_pred = pred[0, 0, 0] + mem_pred = pred[0, 1, 0] + nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1)) + mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1)) + + return nuc_pred, mem_pred + + +# Load the custom CSS from the file +def load_css(file_path): + with open(file_path, "r") as file: + return file.read() + + +# %% +if __name__ == "__main__": + # Download the model checkpoint from Hugging Face + model_ckpt_path = hf_hub_download( + repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt" + ) + + # Model configuration + model_config = { + "in_channels": 1, + "out_channels": 2, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [1, 2, 2], + "in_stack_depth": 1, + "pretraining": False, + } + + # Initialize the Gradio app using Blocks + with gr.Blocks(css=load_css("style.css")) as demo: + # Title and description + gr.HTML( + "
Image Translation (Virtual Staining) of cellular landmark organelles
" + ) + # Improved description block with better formatting + gr.HTML( + """ +
+

Model: VSCyto2D

+

+ Input: label-free image (e.g., QPI or phase contrast)
+ Output: two virtually stained channels: one for the nucleus and one for the cell membrane. +

+

+ Check out our preprint: + Liu et al.,Robust virtual staining of landmark organelles +

+
+ """ + ) + + vsgradio = VSGradio(model_config, model_ckpt_path) + + # Layout for input and output images + with gr.Row(): + input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image") + with gr.Column(): + output_nucleus = gr.Image(type="numpy", label="VS Nucleus") + output_membrane = gr.Image(type="numpy", label="VS Membrane") + + # Button to trigger prediction + submit_button = gr.Button("Submit") + + # Define what happens when the button is clicked + submit_button.click( + vsgradio.predict, + inputs=input_image, + outputs=[output_nucleus, output_membrane], + ) + + # Example images and article + gr.Examples( + examples=["examples/a549.png", "examples/hek.png"], inputs=input_image + ) + + # Article or footer information + gr.HTML( + """ +
+

Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI) or Zernike phase contrast.

+

For training, inference and evaluation of the model refer to the GitHub repository.

+
+ """ + ) + + # Launch the Gradio app + demo.launch() From ea79aad410c57f9fbc07599b1f35620c242637ab Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 18 Dec 2024 16:49:43 -0800 Subject: [PATCH 38/44] adding configurable phate arguments via config --- .../examples_cli/predict.yml | 7 ++- viscy/representation/embedding_writer.py | 51 +++++++++++++++---- .../evaluation/dimensionality_reduction.py | 3 +- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/applications/contrastive_phenotyping/examples_cli/predict.yml b/applications/contrastive_phenotyping/examples_cli/predict.yml index 4dd218cf..622f273f 100644 --- a/applications/contrastive_phenotyping/examples_cli/predict.yml +++ b/applications/contrastive_phenotyping/examples_cli/predict.yml @@ -8,7 +8,12 @@ trainer: callbacks: - class_path: viscy.representation.embedding_writer.EmbeddingWriter init_args: - output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr" + output_path: "/path/to/output.zarr" + phate_kwargs: + n_components: 2 + knn: 10 + decay: 50 + gamma: 1 # edit the following lines to specify logging path # - class_path: lightning.pytorch.loggers.TensorBoardLogger # init_args: diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 8ed6b74d..f4bb9c7c 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -12,8 +12,8 @@ from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( - _fit_transform_umap, _fit_transform_phate, + _fit_transform_umap, ) __all__ = ["read_embedding_dataset", "EmbeddingWriter"] @@ -54,16 +54,38 @@ class EmbeddingWriter(BasePredictionWriter): Path to the zarr store. write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional When to write the embeddings, by default 'epoch'. + phate_kwargs : dict, optional + Keyword arguments passed to PHATE, by default None. + Common parameters include: + - knn: int, number of nearest neighbors (default: 5) + - decay: int, decay rate for kernel (default: 40) + - n_jobs: int, number of jobs for parallel processing + - t: int, number of diffusion steps + - potential_method: str, potential method to use + See phate.PHATE for all available parameters. """ def __init__( self, output_path: Path, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch", + phate_kwargs: dict | None = None, ): super().__init__(write_interval) self.output_path = Path(output_path) + # Set default PHATE parameters + default_phate_kwargs = { + "n_components": 2, + "knn": 5, + "decay": 40, + "n_jobs": -1, + "random_state": 42, + } + if phate_kwargs is not None: + default_phate_kwargs.update(phate_kwargs) + self.phate_kwargs = default_phate_kwargs + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if self.output_path.exists(): raise FileExistsError(f"Output path {self.output_path} already exists.") @@ -76,7 +98,7 @@ def write_on_epoch_end( predictions: Sequence[ContrastivePrediction], batch_indices: Sequence[int], ) -> None: - """Write predictions and the 2-component UMAP of features to a zarr store. + """Write predictions and dimensionality reductions to a zarr store. Parameters ---------- @@ -92,15 +114,21 @@ def write_on_epoch_end( features = _move_and_stack_embeddings(predictions, "features") projections = _move_and_stack_embeddings(predictions, "projections") ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) - _logger.info(f"Computing UMAP embeddings for {len(features)} samples.") + + _logger.info( + f"Computing dimensionality reductions for {len(features)} samples." + ) _, umap = _fit_transform_umap(features, n_components=2, normalize=True) _, phate = _fit_transform_phate( - features, n_components=2, knn=5, decay=40, n_jobs=-1 + features, + **self.phate_kwargs, ) - ultrack_indices["UMAP1"] = umap[:, 0] - ultrack_indices["UMAP2"] = umap[:, 1] - ultrack_indices["PHATE1"] = phate[:, 0] - ultrack_indices["PHATE2"] = phate[:, 1] + + # Add dimensionality reduction coordinates + ultrack_indices["UMAP1"], ultrack_indices["UMAP2"] = umap[:, 0], umap[:, 1] + ultrack_indices["PHATE1"], ultrack_indices["PHATE2"] = phate[:, 0], phate[:, 1] + + # Create multi-index and dataset index = pd.MultiIndex.from_frame(ultrack_indices) dataset = Dataset( { @@ -109,6 +137,7 @@ def write_on_epoch_end( }, coords={"sample": index}, ).reset_index("sample") - _logger.debug(f"Wrtiting predictions dataset:\n{dataset}") - zarr_store = dataset.to_zarr(self.output_path, mode="w") - zarr_store.close() + + _logger.debug(f"Writing predictions dataset:\n{dataset}") + with dataset.to_zarr(self.output_path, mode="w") as zarr_store: + zarr_store.close() diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 721bb8ed..83faccb1 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -105,10 +105,11 @@ def _fit_transform_phate( knn: int = 5, decay: int = 40, n_jobs: int = -1, + **phate_kwargs, ) -> tuple[phate.PHATE, NDArray]: """Fit PHATE model and transform embeddings.""" phate_model = phate.PHATE( - n_components=n_components, knn=knn, decay=decay, n_jobs=n_jobs + n_components=n_components, knn=knn, decay=decay, n_jobs=n_jobs, **phate_kwargs ) phate_embedding = phate_model.fit_transform(embeddings) return phate_model, phate_embedding From 47d84b787537e9310442ca4719db95b655043ddc Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 18 Dec 2024 16:53:46 -0800 Subject: [PATCH 39/44] script to recompute phate and overwrite the previous phate data --- viscy/scripts/recompute_phate.py | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 viscy/scripts/recompute_phate.py diff --git a/viscy/scripts/recompute_phate.py b/viscy/scripts/recompute_phate.py new file mode 100644 index 00000000..1db0794f --- /dev/null +++ b/viscy/scripts/recompute_phate.py @@ -0,0 +1,46 @@ +import logging +from pathlib import Path +from typing import Any + +from xarray import open_zarr +from viscy.representation.evaluation.dimensionality_reduction import ( + _fit_transform_phate, +) + +_logger = logging.getLogger(__name__) + + +def update_phate_embeddings( + dataset_path: Path, + phate_kwargs: dict[str, Any], +) -> None: + """ + Update PHATE embeddings in an existing dataset with new parameters. + + Parameters + ---------- + dataset_path : Path + Path to the zarr store containing embeddings + phate_kwargs : dict + New PHATE parameters to use for recomputing embeddings. + Common parameters include: + - n_components: int, number of dimensions (default: 2) + - knn: int, number of nearest neighbors (default: 5) + - decay: int, decay rate for kernel (default: 40) + - n_jobs: int, number of jobs for parallel processing + - t: int, number of diffusion steps + - gamma: float, gamma parameter for kernel + """ + # Load dataset + dataset = open_zarr(dataset_path, mode="r+") + features = dataset["features"].values + + # Compute new PHATE embeddings + _logger.info(f"Computing PHATE embeddings with parameters: {phate_kwargs}") + _, phate = _fit_transform_phate(features, **phate_kwargs) + + # Update PHATE coordinates + dataset["PHATE1"].values = phate[:, 0] + dataset["PHATE2"].values = phate[:, 1] + + _logger.info(f"Updated PHATE embeddings in {dataset_path}") From 2546cc90e4f7375dc821c54facb5b2d7074dce10 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 18 Dec 2024 17:49:49 -0800 Subject: [PATCH 40/44] ruff --- viscy/scripts/recompute_phate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/viscy/scripts/recompute_phate.py b/viscy/scripts/recompute_phate.py index 1db0794f..41133c97 100644 --- a/viscy/scripts/recompute_phate.py +++ b/viscy/scripts/recompute_phate.py @@ -3,6 +3,7 @@ from typing import Any from xarray import open_zarr + from viscy.representation.evaluation.dimensionality_reduction import ( _fit_transform_phate, ) From f40f316096053e49c7b5a661a6125cc6edf0bf6b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 20 Dec 2024 14:00:45 -0800 Subject: [PATCH 41/44] solving redundancies --- .../evaluation/dimensionality_reduction.py | 28 +++++++++++++++---- viscy/scripts/recompute_phate.py | 13 ++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 83faccb1..a6645fe7 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -14,22 +14,25 @@ def compute_phate( n_components: int = None, knn: int = 5, decay: int = 40, + update_dataset: bool = False, **phate_kwargs, ) -> tuple[phate.PHATE, NDArray]: """ - Compute PHATE embeddings for features - + Compute PHATE embeddings for features and optionally update dataset. Parameters ---------- - embedding_dataset : xarray.Dataset - The dataset containing embeddings, timepoints, fov_name, and track_id. + embedding_dataset : xarray.Dataset or NDArray + The dataset containing embeddings, timepoints, fov_name, and track_id, + or a numpy array of embeddings. n_components : int, optional Number of dimensions in the PHATE embedding, by default None knn : int, optional Number of nearest neighbors to use in the KNN graph, by default 5 decay : int, optional Decay parameter for the Markov operator, by default 40 + update_dataset : bool, optional + Whether to update the PHATE coordinates in the dataset, by default False phate_kwargs : dict, optional Additional keyword arguments for PHATE, by default None @@ -40,10 +43,25 @@ def compute_phate( """ import phate + # Get embeddings from dataset if needed + embeddings = ( + embedding_dataset["features"].values + if isinstance(embedding_dataset, Dataset) + else embedding_dataset + ) + + # Compute PHATE embeddings phate_model = phate.PHATE( n_components=n_components, knn=knn, decay=decay, **phate_kwargs ) - phate_embedding = phate_model.fit_transform(embedding_dataset["features"].values) + phate_embedding = phate_model.fit_transform(embeddings) + + # Update dataset if requested + if update_dataset and isinstance(embedding_dataset, Dataset): + for i in range( + min(2, phate_embedding.shape[1]) + ): # Only update PHATE1 and PHATE2 + embedding_dataset[f"PHATE{i+1}"].values = phate_embedding[:, i] return phate_model, phate_embedding diff --git a/viscy/scripts/recompute_phate.py b/viscy/scripts/recompute_phate.py index 41133c97..fccd5f11 100644 --- a/viscy/scripts/recompute_phate.py +++ b/viscy/scripts/recompute_phate.py @@ -4,9 +4,7 @@ from xarray import open_zarr -from viscy.representation.evaluation.dimensionality_reduction import ( - _fit_transform_phate, -) +from viscy.representation.evaluation.dimensionality_reduction import compute_phate _logger = logging.getLogger(__name__) @@ -34,14 +32,9 @@ def update_phate_embeddings( """ # Load dataset dataset = open_zarr(dataset_path, mode="r+") - features = dataset["features"].values - # Compute new PHATE embeddings + # Compute new PHATE embeddings and update dataset _logger.info(f"Computing PHATE embeddings with parameters: {phate_kwargs}") - _, phate = _fit_transform_phate(features, **phate_kwargs) - - # Update PHATE coordinates - dataset["PHATE1"].values = phate[:, 0] - dataset["PHATE2"].values = phate[:, 1] + compute_phate(dataset, update_dataset=True, **phate_kwargs) _logger.info(f"Updated PHATE embeddings in {dataset_path}") From 745911ae9a26c69e7ae82c925fbfafa8e563d9a6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 20 Dec 2024 16:31:49 -0800 Subject: [PATCH 42/44] modularizing the smoothness --- .../cosine_dissimilarity_dataset.py | 255 ++++++++++++------ 1 file changed, 177 insertions(+), 78 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index 021eb432..781b4e04 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -1,5 +1,6 @@ # %% from pathlib import Path +from typing import Optional import matplotlib.pyplot as plt import seaborn as sns @@ -17,11 +18,11 @@ from tqdm import tqdm import pandas as pd +from scipy.stats import gaussian_kde +from scipy.optimize import minimize_scalar -plt.style.use("../evaluation/figure.mplstyle") -# plotting -VERBOSE = False +plt.style.use("../evaluation/figure.mplstyle") def compute_piece_wise_dissimilarity( @@ -63,20 +64,57 @@ def plot_histogram( plt.show() -# %% -PATH_TO_GDRIVE_FIGUE = "/home/eduardo.hirata/mydata/gdrive/publications/learning_impacts_of_infection/fig_manuscript/rev2_ICLR_fig/" +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. -prediction_path_1 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" -) -prediction_path_2 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" -) + Args: + data: Array of values to find the peak for -for prediction_path, loss_name in tqdm( - [(prediction_path_1, "ntxent"), (prediction_path_2, "triplet")] -): + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" + ) + return result.x + + +def analyze_embedding_smoothness( + prediction_path: Path, + verbose: bool = False, + output_path: Optional[str] = None, + loss_name: Optional[str] = None, + overwrite: bool = False, +) -> dict: + """ + Analyze the smoothness and dynamic range of embeddings. + Args: + prediction_path: Path to the embedding dataset + verbose: If True, generates additional plots + output_path: Path to save the final plot (optional) + loss_name: Name of the loss function used (optional) + overwrite: If True, overwrites existing files. If False, raises error if file exists (default: False) + + Returns: + dict: Dictionary containing metrics including: + - dissimilarity_mean: Mean of adjacent frame dissimilarity + - dissimilarity_std: Standard deviation of adjacent frame dissimilarity + - dissimilarity_median: Median of adjacent frame dissimilarity + - dissimilarity_peak: Peak of adjacent frame distribution + - dissimilarity_p99: 99th percentile of adjacent frame dissimilarity + - dissimilarity_p1: 1st percentile of adjacent frame dissimilarity + - dissimilarity_distribution: Full distribution of adjacent frame dissimilarities + - random_mean: Mean of random sampling dissimilarity + - random_std: Standard deviation of random sampling dissimilarity + - random_median: Median of random sampling dissimilarity + - random_peak: Peak of random sampling distribution + - random_distribution: Full distribution of random sampling dissimilarities + - dynamic_range: Difference between random and adjacent peaks + """ # Read the dataset embeddings = read_embedding_dataset(prediction_path) features = embeddings["features"] @@ -86,10 +124,6 @@ def plot_histogram( cross_dist = cross_dissimilarity(scaled_features, metric="cosine") rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - plt.figure() - plt.imshow(cross_dist) - plt.show() - # Compute piece-wise dissimilarity and rank difference features_df = features["sample"].to_dataframe().reset_index(drop=True) piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( @@ -98,10 +132,6 @@ def plot_histogram( all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) - # # Get the median/mode of the off diagonal elements - # median_piece_wise_dissimilarity = np.array( - # [np.median(track) for track in piece_wise_dissimilarity_per_track] - # ) p99_piece_wise_dissimilarity = np.array( [np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track] ) @@ -109,80 +139,149 @@ def plot_histogram( [np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track] ) - # Random sampling values in the dissimilarity matrix - n_samples = 3000 + # Random sampling values in the dissimilarity matrix with same size as adjacent frame measurements + n_samples = len(all_dissimilarity) random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] - print(f"Dissimilarity Statistics for {prediction_path.stem}") - print(f"Mean: {np.mean(all_dissimilarity)}") - print(f"Std: {np.std(all_dissimilarity)}") - print(f"Median: {np.median(all_dissimilarity)}") + # Compute the peaks of both distributions using KDE + adjacent_peak = float(find_distribution_peak(all_dissimilarity)) + random_peak = float(find_distribution_peak(sampled_values)) + dynamic_range = float(random_peak - adjacent_peak) - print(f"Distance Statistics for random sampling") - print(f"Mean: {np.mean(sampled_values)}") - print(f"Std: {np.std(sampled_values)}") - print(f"Median: {np.median(sampled_values)}") + metrics = { + "dissimilarity_mean": float(np.mean(all_dissimilarity)), + "dissimilarity_std": float(np.std(all_dissimilarity)), + "dissimilarity_median": float(np.median(all_dissimilarity)), + "dissimilarity_peak": adjacent_peak, + "dissimilarity_p99": p99_piece_wise_dissimilarity, + "dissimilarity_p1": p1_percentile_piece_wise_dissimilarity, + "dissimilarity_distribution": all_dissimilarity, + "random_mean": float(np.mean(sampled_values)), + "random_std": float(np.std(sampled_values)), + "random_median": float(np.median(sampled_values)), + "random_peak": random_peak, + "random_distribution": sampled_values, + "dynamic_range": dynamic_range, + } - if VERBOSE: - # Plot histograms - # plot_histogram( - # median_piece_wise_dissimilarity, - # "Adjacent Frame Median Dissimilarity per Track", - # "Cosine Dissimilarity", - # "Frequency", - # ) - # plot_histogram( - # p1_percentile_piece_wise_dissimilarity, - # "Adjacent Frame 1 Percentile Dissimilarity per Track", - # "Cosine Dissimilarity", - # "Frequency", - # ) - # plot_histogram( - # p99_piece_wise_dissimilarity, - # "Adjacent Frame 99 Percentile Dissimilarity per Track", - # "Cosine Dissimilarity", - # "Frequency", - # ) + if verbose: + # Plot cross distance matrix + plt.figure() + plt.imshow(cross_dist) + plt.show() + # Plot histograms plot_histogram( piece_wise_dissimilarity_per_track, "Adjacent Frame Dissimilarity per Track", "Cosine Dissimilarity", "Frequency", ) - # Plot the histogram of the sampled values - plot_histogram( - sampled_values, - "Random Sampling of Dissimilarity Values", - "Cosine Dissimilarity", - "Density", + + # Plot the comparison histogram and save if output_path is provided + fig = plt.figure() + sns.histplot( + metrics["dissimilarity_distribution"], + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + metrics["random_distribution"], + bins=30, + kde=True, color="red", alpha=0.5, stat="density", ) + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + # Add vertical lines for the peaks + plt.axvline( + x=metrics["dissimilarity_peak"], color="cyan", linestyle="--", alpha=0.8 + ) + plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) - # Plot the median and the random sampling in one plot each with different colors - fig = plt.figure() - sns.histplot( - all_dissimilarity, - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", - ) - sns.histplot( - sampled_values, bins=30, kde=True, color="red", alpha=0.5, stat="density" + if output_path and loss_name: + output_file = Path( + f"{output_path}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf" + ) + if output_file.exists() and not overwrite: + raise FileExistsError( + f"File {output_file} already exists and overwrite=False" + ) + fig.savefig( + output_file, + dpi=600, + ) + plt.show() + + return metrics + + +# Example usage: +if __name__ == "__main__": + # plotting + VERBOSE = True + + PATH_TO_GDRIVE_FIGUE = "./" + + prediction_path_1 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" ) - plt.xlabel("Cosine Dissimilarity") - plt.ylabel("Density") - plt.tight_layout() - plt.legend(["Adjacent Frame", "Random Sample"]) - plt.show() - fig.savefig( - f"{PATH_TO_GDRIVE_FIGUE}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf", - dpi=600, + prediction_path_2 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" ) + # Create a list of models to evaluate + models = [ + (prediction_path_1, "ntxent"), + (prediction_path_2, "triplet"), + ] + + # Evaluate each model + for prediction_path, loss_name in tqdm(models, desc="Evaluating models"): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {loss_name})") + print("-" * 80) + + metrics = analyze_embedding_smoothness( + prediction_path, + verbose=VERBOSE, + output_path=PATH_TO_GDRIVE_FIGUE, + loss_name=loss_name, + overwrite=True, + ) + + # Print adjacent frame dissimilarity statistics + print("\nAdjacent Frame Dissimilarity Statistics:") + print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") + print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") + print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") + print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") + print(f"{'P1:':<15} {np.mean(metrics['dissimilarity_p1']):.3f}") + print(f"{'P99:':<15} {np.mean(metrics['dissimilarity_p99']):.3f}") + + # Print random sampling statistics + print("\nRandom Sampling Statistics:") + print(f"{'Mean:':<15} {metrics['random_mean']:.3f}") + print(f"{'Std:':<15} {metrics['random_std']:.3f}") + print(f"{'Median:':<15} {metrics['random_median']:.3f}") + print(f"{'Peak:':<15} {metrics['random_peak']:.3f}") + + # Print dynamic range + print("\nComparison Metrics:") + print(f"{'Dynamic Range:':<15} {metrics['dynamic_range']:.3f}") + + # Print distribution sizes + print("\nDistribution Sizes:") + print( + f"{'Adjacent Frame:':<15} {len(metrics['dissimilarity_distribution']):,d} samples" + ) + print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + # %% From 965abbf7170ca2c102579f1ad92a18b0a9b46445 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 21 Dec 2024 14:28:07 -0800 Subject: [PATCH 43/44] removing redundant _fit_phate() --- viscy/representation/embedding_writer.py | 4 ++-- .../evaluation/dimensionality_reduction.py | 18 +----------------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index f4bb9c7c..d43ad735 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -12,7 +12,7 @@ from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( - _fit_transform_phate, + compute_phate, _fit_transform_umap, ) @@ -119,7 +119,7 @@ def write_on_epoch_end( f"Computing dimensionality reductions for {len(features)} samples." ) _, umap = _fit_transform_umap(features, n_components=2, normalize=True) - _, phate = _fit_transform_phate( + _, phate = compute_phate( features, **self.phate_kwargs, ) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index a6645fe7..6a058ac7 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -11,7 +11,7 @@ def compute_phate( embedding_dataset, - n_components: int = None, + n_components: int = 2, knn: int = 5, decay: int = 40, update_dataset: bool = False, @@ -117,22 +117,6 @@ def _fit_transform_umap( return umap_model, umap_embedding -def _fit_transform_phate( - embeddings: NDArray, - n_components: int = 2, - knn: int = 5, - decay: int = 40, - n_jobs: int = -1, - **phate_kwargs, -) -> tuple[phate.PHATE, NDArray]: - """Fit PHATE model and transform embeddings.""" - phate_model = phate.PHATE( - n_components=n_components, knn=knn, decay=decay, n_jobs=n_jobs, **phate_kwargs - ) - phate_embedding = phate_model.fit_transform(embeddings) - return phate_model, phate_embedding - - def compute_umap( embedding_dataset: Dataset, normalize_features: bool = True ) -> tuple[umap.UMAP, umap.UMAP, pd.DataFrame]: From 3f01a9699aa36c222c2ce130b2b86fe85c2dc0e6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 21 Dec 2024 14:28:58 -0800 Subject: [PATCH 44/44] ruff --- viscy/representation/embedding_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index d43ad735..c188c739 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -12,8 +12,8 @@ from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( - compute_phate, _fit_transform_umap, + compute_phate, ) __all__ = ["read_embedding_dataset", "EmbeddingWriter"]