diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py new file mode 100644 index 00000000..595f283f --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py @@ -0,0 +1,312 @@ +# %% +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.distance import ( + calculate_normalized_euclidean_distance_cell, + compute_displacement, + compute_dynamic_range, + compute_rms_per_track, +) +from collections import defaultdict +from tabulate import tabulate + +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from collections import OrderedDict + +# %% function + +# Removed redundant compute_displacement_mean_std_full function +# Removed redundant compute_dynamic_range and compute_rms_per_track functions + + +def plot_rms_histogram(rms_values, label, bins=30): + """ + Plot histogram of RMS values across tracks. + + Parameters: + rms_values : list + List of RMS values, one for each track. + label : str + Label for the dataset (used in the title). + bins : int, optional + Number of bins for the histogram. Default is 30. + + Returns: + None: Displays the histogram. + """ + plt.figure(figsize=(10, 6)) + plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black") + plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16) + plt.xlabel("RMS of Time Derivative", fontsize=14) + plt.ylabel("Frequency", fontsize=14) + plt.grid(True) + plt.show() + + +def plot_displacement( + mean_displacement, std_displacement, label, metrics_no_track=None +): + """ + Plot embedding displacement over time with mean and standard deviation. + + Parameters: + mean_displacement : dict + Mean displacement for each tau. + std_displacement : dict + Standard deviation of displacement for each tau. + label : str + Label for the dataset. + metrics_no_track : dict, optional + Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against. + + Returns: + None: Displays the plot. + """ + plt.figure(figsize=(10, 6)) + taus = list(mean_displacement.keys()) + mean_values = list(mean_displacement.values()) + std_values = list(std_displacement.values()) + + plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green") + plt.fill_between( + taus, + np.array(mean_values) - np.array(std_values), + np.array(mean_values) + np.array(std_values), + color="green", + alpha=0.3, + label=f"Std Dev ({label})", + ) + + if metrics_no_track: + mean_values_no_track = list(metrics_no_track["mean_displacement"].values()) + std_values_no_track = list(metrics_no_track["std_displacement"].values()) + + plt.plot( + taus, + mean_values_no_track, + marker="o", + label="Classical Contrastive (No Tracking)", + color="blue", + ) + plt.fill_between( + taus, + np.array(mean_values_no_track) - np.array(std_values_no_track), + np.array(mean_values_no_track) + np.array(std_values_no_track), + color="blue", + alpha=0.3, + label="Std Dev (No Tracking)", + ) + + plt.xlabel("Time Shift (τ)", fontsize=14) + plt.ylabel("Euclidean Distance", fontsize=14) + plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16) + plt.grid(True) + plt.legend(fontsize=12) + plt.show() + + +def plot_overlay_displacement(overlay_displacement_data): + """ + Plot embedding displacement over time for all datasets in one plot. 
+ + Parameters: + overlay_displacement_data : dict + A dictionary containing mean displacement per tau for all datasets. + + Returns: + None: Displays the plot. + """ + plt.figure(figsize=(12, 8)) + for label, mean_displacement in overlay_displacement_data.items(): + taus = list(mean_displacement.keys()) + mean_values = list(mean_displacement.values()) + plt.plot(taus, mean_values, marker="o", label=label) + + plt.xlabel("Time Shift (τ)", fontsize=14) + plt.ylabel("Euclidean Distance", fontsize=14) + plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) + plt.grid(True) + plt.legend(fontsize=12) + plt.show() + + +# %% hist stats +def plot_boxplot_rms_across_models(datasets_rms): + """ + Plot a boxplot for the distribution of RMS values across models. + + Parameters: + datasets_rms : dict + A dictionary where keys are dataset names and values are lists of RMS values. + + Returns: + None: Displays the boxplot. + """ + plt.figure(figsize=(12, 6)) + labels = list(datasets_rms.keys()) + data = list(datasets_rms.values()) + print(labels) + print(data) + # Plot the boxplot + plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) + + plt.title( + "Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16 + ) + plt.ylabel("RMS of Time Derivative", fontsize=14) + plt.xticks(rotation=45, fontsize=12) + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.tight_layout() + plt.show() + + +def plot_histogram_absolute_differences(datasets_abs_diff): + """ + Plot histograms of absolute differences across embeddings for all models. + + Parameters: + datasets_abs_diff : dict + A dictionary where keys are dataset names and values are lists of absolute differences. + + Returns: + None: Displays the histograms. + """ + plt.figure(figsize=(12, 6)) + for label, abs_diff in datasets_abs_diff.items(): + plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) + + plt.title("Histograms of Absolute Differences Across Models", fontsize=16) + plt.xlabel("Absolute Difference", fontsize=14) + plt.ylabel("Density", fontsize=14) + plt.legend(fontsize=12) + plt.grid(alpha=0.7) + plt.tight_layout() + plt.show() + + +# %% Paths to datasets +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", + "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", + "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", + "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", +} + +no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" + +# %% Process Datasets +max_tau = 69 +metrics = {} + +overlay_displacement_data = {} +datasets_rms = {} +datasets_abs_diff = {} + +# Process "No Tracking" dataset +features_path_no_track = Path(no_track_path) +embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) + +mean_displacement_no_track, std_displacement_no_track = compute_displacement( + embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True +) +dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) +metrics["No Tracking"] = { + "dynamic_range": dynamic_range_no_track, + "mean_displacement": 
mean_displacement_no_track, + "std_displacement": std_displacement_no_track, +} + +overlay_displacement_data["No Tracking"] = mean_displacement_no_track + +print("\nProcessing No Tracking dataset...") +print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") + +plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking") + +rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) +datasets_rms["No Tracking"] = rms_values_no_track + +print(f"Plotting histogram of RMS values for No Tracking dataset...") +plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) + +# Compute absolute differences for "No Tracking" +abs_diff_no_track = np.concatenate( + [ + np.linalg.norm( + np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), + axis=-1, + ) + for indices in np.split( + np.arange(len(embedding_dataset_no_track["track_id"])), + np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1, + ) + ] +) +datasets_abs_diff["No Tracking"] = abs_diff_no_track + +# Process other datasets +for label, path in feature_paths.items(): + print(f"\nProcessing {label} dataset...") + + features_path = Path(path) + embedding_dataset = read_embedding_dataset(features_path) + + mean_displacement, std_displacement = compute_displacement( + embedding_dataset, max_tau=max_tau, return_mean_std=True + ) + dynamic_range = compute_dynamic_range(mean_displacement) + metrics[label] = { + "dynamic_range": dynamic_range, + "mean_displacement": mean_displacement, + "std_displacement": std_displacement, + } + + overlay_displacement_data[label] = mean_displacement + + print(f"Dynamic Range for {label}: {dynamic_range}") + + plot_displacement( + mean_displacement, + std_displacement, + label, + metrics_no_track=metrics.get("No Tracking", None), + ) + + rms_values = compute_rms_per_track(embedding_dataset) + datasets_rms[label] = rms_values + + print(f"Plotting histogram of RMS values for {label}...") + plot_rms_histogram(rms_values, label, bins=30) + + abs_diff = np.concatenate( + [ + np.linalg.norm( + np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1 + ) + for indices in np.split( + np.arange(len(embedding_dataset["track_id"])), + np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1, + ) + ] + ) + datasets_abs_diff[label] = abs_diff + +print("\nPlotting overlayed displacement for all datasets...") +plot_overlay_displacement(overlay_displacement_data) + +print("\nSummary of Dynamic Ranges:") +for label, metric in metrics.items(): + print(f"{label}: Dynamic Range = {metric['dynamic_range']}") + +print("\nPlotting RMS boxplot across models...") +plot_boxplot_rms_across_models(datasets_rms) + +print("\nPlotting histograms of absolute differences across models...") +plot_histogram_absolute_differences(datasets_abs_diff) + + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py new file mode 100644 index 00000000..781b4e04 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -0,0 +1,287 @@ +# %% +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.preprocessing import StandardScaler +from numpy.typing import NDArray + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.clustering import ( + compare_time_offset, + 
cross_dissimilarity, + rank_nearest_neighbors, + select_block, +) +import numpy as np +from tqdm import tqdm +import pandas as pd + +from scipy.stats import gaussian_kde +from scipy.optimize import minimize_scalar + + +plt.style.use("../evaluation/figure.mplstyle") + + +def compute_piece_wise_dissimilarity( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def plot_histogram( + data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" +): + plt.figure() + plt.title(title) + sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.tight_layout() + plt.show() + + +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. + + Args: + data: Array of values to find the peak for + + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" + ) + return result.x + + +def analyze_embedding_smoothness( + prediction_path: Path, + verbose: bool = False, + output_path: Optional[str] = None, + loss_name: Optional[str] = None, + overwrite: bool = False, +) -> dict: + """ + Analyze the smoothness and dynamic range of embeddings. + + Args: + prediction_path: Path to the embedding dataset + verbose: If True, generates additional plots + output_path: Path to save the final plot (optional) + loss_name: Name of the loss function used (optional) + overwrite: If True, overwrites existing files. 
If False, raises error if file exists (default: False) + + Returns: + dict: Dictionary containing metrics including: + - dissimilarity_mean: Mean of adjacent frame dissimilarity + - dissimilarity_std: Standard deviation of adjacent frame dissimilarity + - dissimilarity_median: Median of adjacent frame dissimilarity + - dissimilarity_peak: Peak of adjacent frame distribution + - dissimilarity_p99: 99th percentile of adjacent frame dissimilarity + - dissimilarity_p1: 1st percentile of adjacent frame dissimilarity + - dissimilarity_distribution: Full distribution of adjacent frame dissimilarities + - random_mean: Mean of random sampling dissimilarity + - random_std: Standard deviation of random sampling dissimilarity + - random_median: Median of random sampling dissimilarity + - random_peak: Peak of random sampling distribution + - random_distribution: Full distribution of random sampling dissimilarities + - dynamic_range: Difference between random and adjacent peaks + """ + # Read the dataset + embeddings = read_embedding_dataset(prediction_path) + features = embeddings["features"] + + scaled_features = StandardScaler().fit_transform(features.values) + # Compute the cosine dissimilarity + cross_dist = cross_dissimilarity(scaled_features, metric="cosine") + rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) + + # Compute piece-wise dissimilarity and rank difference + features_df = features["sample"].to_dataframe().reset_index(drop=True) + piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( + compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) + ) + + all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) + + p99_piece_wise_dissimilarity = np.array( + [np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track] + ) + p1_percentile_piece_wise_dissimilarity = np.array( + [np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track] + ) + + # Random sampling values in the dissimilarity matrix with same size as adjacent frame measurements + n_samples = len(all_dissimilarity) + random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) + sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] + + # Compute the peaks of both distributions using KDE + adjacent_peak = float(find_distribution_peak(all_dissimilarity)) + random_peak = float(find_distribution_peak(sampled_values)) + dynamic_range = float(random_peak - adjacent_peak) + + metrics = { + "dissimilarity_mean": float(np.mean(all_dissimilarity)), + "dissimilarity_std": float(np.std(all_dissimilarity)), + "dissimilarity_median": float(np.median(all_dissimilarity)), + "dissimilarity_peak": adjacent_peak, + "dissimilarity_p99": p99_piece_wise_dissimilarity, + "dissimilarity_p1": p1_percentile_piece_wise_dissimilarity, + "dissimilarity_distribution": all_dissimilarity, + "random_mean": float(np.mean(sampled_values)), + "random_std": float(np.std(sampled_values)), + "random_median": float(np.median(sampled_values)), + "random_peak": random_peak, + "random_distribution": sampled_values, + "dynamic_range": dynamic_range, + } + + if verbose: + # Plot cross distance matrix + plt.figure() + plt.imshow(cross_dist) + plt.show() + + # Plot histograms + plot_histogram( + piece_wise_dissimilarity_per_track, + "Adjacent Frame Dissimilarity per Track", + "Cosine Dissimilarity", + "Frequency", + ) + + # Plot the comparison histogram and save if output_path is provided + fig = plt.figure() + sns.histplot( + 
metrics["dissimilarity_distribution"], + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + metrics["random_distribution"], + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", + ) + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + # Add vertical lines for the peaks + plt.axvline( + x=metrics["dissimilarity_peak"], color="cyan", linestyle="--", alpha=0.8 + ) + plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + + if output_path and loss_name: + output_file = Path( + f"{output_path}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf" + ) + if output_file.exists() and not overwrite: + raise FileExistsError( + f"File {output_file} already exists and overwrite=False" + ) + fig.savefig( + output_file, + dpi=600, + ) + plt.show() + + return metrics + + +# Example usage: +if __name__ == "__main__": + # plotting + VERBOSE = True + + PATH_TO_GDRIVE_FIGUE = "./" + + prediction_path_1 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" + ) + prediction_path_2 = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + ) + + # Create a list of models to evaluate + models = [ + (prediction_path_1, "ntxent"), + (prediction_path_2, "triplet"), + ] + + # Evaluate each model + for prediction_path, loss_name in tqdm(models, desc="Evaluating models"): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {loss_name})") + print("-" * 80) + + metrics = analyze_embedding_smoothness( + prediction_path, + verbose=VERBOSE, + output_path=PATH_TO_GDRIVE_FIGUE, + loss_name=loss_name, + overwrite=True, + ) + + # Print adjacent frame dissimilarity statistics + print("\nAdjacent Frame Dissimilarity Statistics:") + print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") + print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") + print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") + print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") + print(f"{'P1:':<15} {np.mean(metrics['dissimilarity_p1']):.3f}") + print(f"{'P99:':<15} {np.mean(metrics['dissimilarity_p99']):.3f}") + + # Print random sampling statistics + print("\nRandom Sampling Statistics:") + print(f"{'Mean:':<15} {metrics['random_mean']:.3f}") + print(f"{'Std:':<15} {metrics['random_std']:.3f}") + print(f"{'Median:':<15} {metrics['random_median']:.3f}") + print(f"{'Peak:':<15} {metrics['random_peak']:.3f}") + + # Print dynamic range + print("\nComparison Metrics:") + print(f"{'Dynamic Range:':<15} {metrics['dynamic_range']:.3f}") + + # Print distribution sizes + print("\nDistribution Sizes:") + print( + f"{'Adjacent Frame:':<15} {len(metrics['dissimilarity_distribution']):,d} samples" + ) + print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + +# %% diff --git a/applications/contrastive_phenotyping/examples_cli/predict.yml b/applications/contrastive_phenotyping/examples_cli/predict.yml index 4dd218cf..622f273f 100644 --- a/applications/contrastive_phenotyping/examples_cli/predict.yml +++ b/applications/contrastive_phenotyping/examples_cli/predict.yml @@ -8,7 +8,12 @@ trainer: callbacks: - class_path: 
viscy.representation.embedding_writer.EmbeddingWriter init_args: - output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr" + output_path: "/path/to/output.zarr" + phate_kwargs: + n_components: 2 + knn: 10 + decay: 50 + gamma: 1 # edit the following lines to specify logging path # - class_path: lightning.pytorch.loggers.TensorBoardLogger # init_args: diff --git a/pyproject.toml b/pyproject.toml index 32f196c8..e4eb3a07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ metrics = [ "ptflops>=0.7", "umap-learn", "captum>=0.7.0", + "phate", ] examples = ["napari", "jupyter", "jupytext"] visual = [ diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 4f324047..c188c739 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -13,6 +13,7 @@ from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( _fit_transform_umap, + compute_phate, ) __all__ = ["read_embedding_dataset", "EmbeddingWriter"] @@ -53,16 +54,38 @@ class EmbeddingWriter(BasePredictionWriter): Path to the zarr store. write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional When to write the embeddings, by default 'epoch'. + phate_kwargs : dict, optional + Keyword arguments passed to PHATE, by default None. + Common parameters include: + - knn: int, number of nearest neighbors (default: 5) + - decay: int, decay rate for kernel (default: 40) + - n_jobs: int, number of jobs for parallel processing + - t: int, number of diffusion steps + - potential_method: str, potential method to use + See phate.PHATE for all available parameters. """ def __init__( self, output_path: Path, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch", + phate_kwargs: dict | None = None, ): super().__init__(write_interval) self.output_path = Path(output_path) + # Set default PHATE parameters + default_phate_kwargs = { + "n_components": 2, + "knn": 5, + "decay": 40, + "n_jobs": -1, + "random_state": 42, + } + if phate_kwargs is not None: + default_phate_kwargs.update(phate_kwargs) + self.phate_kwargs = default_phate_kwargs + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if self.output_path.exists(): raise FileExistsError(f"Output path {self.output_path} already exists.") @@ -75,7 +98,7 @@ def write_on_epoch_end( predictions: Sequence[ContrastivePrediction], batch_indices: Sequence[int], ) -> None: - """Write predictions and the 2-component UMAP of features to a zarr store. + """Write predictions and dimensionality reductions to a zarr store. Parameters ---------- @@ -91,10 +114,21 @@ def write_on_epoch_end( features = _move_and_stack_embeddings(predictions, "features") projections = _move_and_stack_embeddings(predictions, "projections") ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) - _logger.info(f"Computing UMAP embeddings for {len(features)} samples.") + + _logger.info( + f"Computing dimensionality reductions for {len(features)} samples." 
+ ) _, umap = _fit_transform_umap(features, n_components=2, normalize=True) - ultrack_indices["UMAP1"] = umap[:, 0] - ultrack_indices["UMAP2"] = umap[:, 1] + _, phate = compute_phate( + features, + **self.phate_kwargs, + ) + + # Add dimensionality reduction coordinates + ultrack_indices["UMAP1"], ultrack_indices["UMAP2"] = umap[:, 0], umap[:, 1] + ultrack_indices["PHATE1"], ultrack_indices["PHATE2"] = phate[:, 0], phate[:, 1] + + # Create multi-index and dataset index = pd.MultiIndex.from_frame(ultrack_indices) dataset = Dataset( { @@ -103,6 +137,7 @@ def write_on_epoch_end( }, coords={"sample": index}, ).reset_index("sample") - _logger.debug(f"Wrtiting predictions dataset:\n{dataset}") - zarr_store = dataset.to_zarr(self.output_path, mode="w") - zarr_store.close() + + _logger.debug(f"Writing predictions dataset:\n{dataset}") + with dataset.to_zarr(self.output_path, mode="w") as zarr_store: + zarr_store.close() diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 481df87b..ca67263a 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -7,6 +7,7 @@ from lightning.pytorch import LightningModule from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn +from umap import UMAP from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder @@ -34,6 +35,7 @@ def __init__( schedule: Literal["WarmupCosine", "Constant"] = "Constant", log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, + log_embeddings: bool = False, example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), ) -> None: super().__init__() @@ -46,6 +48,7 @@ def __init__( self.example_input_array = torch.rand(*example_input_array_shape) self.training_step_outputs = [] self.validation_step_outputs = [] + self.log_embeddings = log_embeddings def forward(self, x: Tensor) -> Tensor: "Only return projected embeddings for training and validation." 
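# Illustrative sketch of the UMAP logging that the log_embeddings flag above enables:
# embeddings are reduced to 2-D with UMAP and sent to the TensorBoard projector.
# This standalone version substitutes a plain SummaryWriter for the Lightning logger;
# the array shape and log directory are placeholders, not values from the repository.
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from umap import UMAP

features = np.random.rand(128, 32)  # placeholder: 128 samples of 32-D embeddings
umap_coords = UMAP(n_components=2).fit_transform(features)
writer = SummaryWriter("tb_logs/umap_example")  # placeholder log directory
writer.add_embedding(umap_coords, global_step=0, tag="val_umap")
writer.close()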
@@ -118,6 +121,19 @@ def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): ) output_list.extend(detach_sample(samples, self.log_samples_per_batch)) + def log_embedding_umap(self, embeddings: Tensor, tag: str): + _logger.debug(f"Computing UMAP for {tag} embeddings.") + umap = UMAP(n_components=2) + embeddings_np = embeddings.detach().cpu().numpy() + umap_embeddings = umap.fit_transform(embeddings_np) + + # Log UMAP embeddings to TensorBoard + self.logger.experiment.add_embedding( + umap_embeddings, + global_step=self.current_epoch, + tag=f"{tag}_umap", + ) + def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_img = batch["anchor"] pos_img = batch["positive"] @@ -152,6 +168,12 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_train_epoch_end(self) -> None: super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) + # Log UMAP embeddings for training + if self.log_embeddings: + embeddings = torch.cat( + [output["embeddings"] for output in self.training_step_outputs] + ) + self.log_embedding_umap(embeddings, tag="train") self.training_step_outputs = [] def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: @@ -189,6 +211,13 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: def on_validation_epoch_end(self) -> None: super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) + # Log UMAP embeddings for validation + if self.log_embeddings: + embeddings = torch.cat( + [output["embeddings"] for output in self.validation_step_outputs] + ) + self.log_embedding_umap(embeddings, tag="val") + + self.validation_step_outputs = [] def configure_optimizers(self): diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 0a906bf4..6a058ac7 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,6 +1,7 @@ """PCA and UMAP dimensionality reduction.""" import pandas as pd +import phate import umap from numpy.typing import NDArray from sklearn.decomposition import PCA @@ -8,6 +9,63 @@ from xarray import Dataset +def compute_phate( + embedding_dataset, + n_components: int = 2, + knn: int = 5, + decay: int = 40, + update_dataset: bool = False, + **phate_kwargs, +) -> tuple[phate.PHATE, NDArray]: + """ + Compute PHATE embeddings for features and optionally update dataset. + + Parameters + ---------- + embedding_dataset : xarray.Dataset or NDArray + The dataset containing embeddings, timepoints, fov_name, and track_id, + or a numpy array of embeddings. 
+ n_components : int, optional + Number of dimensions in the PHATE embedding, by default None + knn : int, optional + Number of nearest neighbors to use in the KNN graph, by default 5 + decay : int, optional + Decay parameter for the Markov operator, by default 40 + update_dataset : bool, optional + Whether to update the PHATE coordinates in the dataset, by default False + phate_kwargs : dict, optional + Additional keyword arguments for PHATE, by default None + + Returns + ------- + phate.PHATE, NDArray + PHATE model and PHATE embeddings + """ + import phate + + # Get embeddings from dataset if needed + embeddings = ( + embedding_dataset["features"].values + if isinstance(embedding_dataset, Dataset) + else embedding_dataset + ) + + # Compute PHATE embeddings + phate_model = phate.PHATE( + n_components=n_components, knn=knn, decay=decay, **phate_kwargs + ) + phate_embedding = phate_model.fit_transform(embeddings) + + # Update dataset if requested + if update_dataset and isinstance(embedding_dataset, Dataset): + for i in range( + min(2, phate_embedding.shape[1]) + ): # Only update PHATE1 and PHATE2 + embedding_dataset[f"PHATE{i+1}"].values = phate_embedding[:, i] + + return phate_model, phate_embedding + + def compute_pca(embedding_dataset, n_components=None, normalize_features=True): features = embedding_dataset["features"] projections = embedding_dataset["projections"] diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 5c8f7759..9a1c72ef 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -6,95 +6,18 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): """Extract embeddings and calculate cosine similarities for a specific cell""" - # Filter the dataset for the specific infected cell filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), drop=True, ) - - # Extract the feature embeddings and time points features = filtered_data["features"].values # (sample, features) time_points = filtered_data["t"].values # (sample,) - - # Get the first time point's embedding first_time_point_embedding = features[0].reshape(1, -1) - - # Calculate cosine similarity between each time point and the first time point - cosine_similarities = [] - for i in range(len(time_points)): - similarity = cosine_similarity( - first_time_point_embedding, features[i].reshape(1, -1) - ) - cosine_similarities.append(similarity[0][0]) - - return time_points, cosine_similarities - - -def compute_displacement_mean_std( - embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False -): - """Compute the norm of differences between embeddings at t and t + tau""" - # Get the arrays of (fov_name, track_id, t, and embeddings) - fov_names = embedding_dataset["fov_name"].values - track_ids = embedding_dataset["track_id"].values - timepoints = embedding_dataset["t"].values - embeddings = embedding_dataset["features"].values - - # Dictionary to store displacements for each tau - displacement_per_tau = defaultdict(list) - - # Iterate over all entries in the dataset - for i in range(len(fov_names)): - fov_name = fov_names[i] - track_id = track_ids[i] - current_time = timepoints[i] - current_embedding = embeddings[i] - - # For each time point t, compute displacements for t + tau - for tau in range(1, max_tau + 1): - future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) - 
matching_indices = np.where( - (fov_names == fov_name) - & (track_ids == track_id) - & (timepoints == future_time) - )[0] - - if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] - - if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity( - current_embedding.reshape(1, -1), - future_embedding.reshape(1, -1), - )[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity - else: - # Compute the Euclidean distance, elementwise square on difference - displacement = np.sum((current_embedding - future_embedding) ** 2) - - # Store the displacement for the given tau - displacement_per_tau[tau].append(displacement) - - # Compute mean and std displacement for each tau by averaging the displacements - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - - return mean_displacement_per_tau, std_displacement_per_tau + cosine_similarities = cosine_similarity( + first_time_point_embedding, features + ).flatten() + return time_points, cosine_similarities.tolist() def compute_displacement( @@ -103,35 +26,30 @@ def compute_displacement( use_cosine=False, use_dissimilarity=False, use_umap=False, + return_mean_std=False, ): """Compute the norm of differences between embeddings at t and t + tau""" - # Get the arrays of (fov_name, track_id, t, and embeddings) fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values if use_umap: - umap1 = embedding_dataset["UMAP1"].values - umap2 = embedding_dataset["UMAP2"].values - embeddings = np.vstack((umap1, umap2)).T + embeddings = np.vstack( + (embedding_dataset["UMAP1"].values, embedding_dataset["UMAP2"].values) + ).T else: embeddings = embedding_dataset["features"].values - # Dictionary to store displacements for each tau displacement_per_tau = defaultdict(list) - # Iterate over all entries in the dataset for i in range(len(fov_names)): fov_name = fov_names[i] track_id = track_ids[i] current_time = timepoints[i] - current_embedding = embeddings[i] + current_embedding = embeddings[i].reshape(1, -1) - # For each time point t, compute displacements for t + tau for tau in range(1, max_tau + 1): future_time = current_time + tau - - # Find if future_time exists for the same (fov_name, track_id) matching_indices = np.where( (fov_names == fov_name) & (track_ids == track_id) @@ -139,56 +57,58 @@ def compute_displacement( )[0] if len(matching_indices) == 1: - # Get the embedding at t + tau - future_embedding = embeddings[matching_indices[0]] + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) if use_cosine: - # Compute cosine similarity - similarity = cosine_similarity( - current_embedding.reshape(1, -1), - future_embedding.reshape(1, -1), - )[0][0] - # Choose whether to use similarity or dissimilarity - if use_dissimilarity: - displacement = 1 - similarity # Cosine dissimilarity - else: - displacement = similarity # Cosine similarity + similarity = cosine_similarity(current_embedding, future_embedding)[ + 0 + ][0] + displacement = 1 - similarity if use_dissimilarity else similarity else: - # Compute the Euclidean distance, elementwise square on difference displacement = 
np.sum((current_embedding - future_embedding) ** 2) - # Store the displacement for the given tau displacement_per_tau[tau].append(displacement) - return displacement_per_tau + if return_mean_std: + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + return mean_displacement_per_tau, std_displacement_per_tau + return displacement_per_tau -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, - ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) +def compute_dynamic_range(mean_displacement_per_tau): + """ + Compute the dynamic range as the difference between the maximum + and minimum mean displacement per τ. - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) + Parameters: + mean_displacement_per_tau: dict with τ as key and mean displacement as value - # Get the first time point's normalized embedding - first_time_point_embedding = normalized_features[0].reshape(1, -1) + Returns: + float: dynamic range (max displacement - min displacement) + """ + displacements = list(mean_displacement_per_tau.values()) + return max(displacements) - min(displacements) - euclidean_distances = [] - for i in range(len(time_points)): - distance = np.linalg.norm( - first_time_point_embedding - normalized_features[i].reshape(1, -1) - ) - euclidean_distances.append(distance) - return time_points, euclidean_distances +def compute_rms_per_track(embedding_dataset): + """ + Compute RMS of the time derivative of embeddings per track. + Parameters: + embedding_dataset : xarray.Dataset + The dataset containing embeddings, timepoints, fov_name, and track_id. -def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): + Returns: + list: A list of RMS values, one for each track. 
+ """ fov_names = embedding_dataset["fov_name"].values track_ids = embedding_dataset["track_id"].values timepoints = embedding_dataset["t"].values @@ -198,52 +118,45 @@ def compute_displacement_mean_std_full(embedding_dataset, max_tau=10): list(zip(fov_names, track_ids)), dtype=[("fov_name", "O"), ("track_id", "int64")], ) - unique_cells = np.unique(cell_identifiers) - displacement_per_tau = defaultdict(list) + rms_values = [] for cell in unique_cells: fov_name = cell["fov_name"] track_id = cell["track_id"] - indices = np.where((fov_names == fov_name) & (track_ids == track_id))[0] - cell_timepoints = timepoints[indices] cell_embeddings = embeddings[indices] + if len(cell_embeddings) < 2: + continue + sorted_indices = np.argsort(cell_timepoints) - cell_timepoints = cell_timepoints[sorted_indices] cell_embeddings = cell_embeddings[sorted_indices] + differences = np.diff(cell_embeddings, axis=0) - for i in range(len(cell_timepoints)): - current_time = cell_timepoints[i] - current_embedding = cell_embeddings[i] - - current_embedding = current_embedding / np.linalg.norm(current_embedding) - - for tau in range(0, max_tau + 1): - future_time = current_time + tau - - future_index = np.where(cell_timepoints == future_time)[0] + if differences.shape[0] == 0: + continue - if len(future_index) >= 1: - future_embedding = cell_embeddings[future_index[0]] - future_embedding = future_embedding / np.linalg.norm( - future_embedding - ) + norms = np.linalg.norm(differences, axis=1) + rms = np.sqrt(np.mean(norms**2)) + rms_values.append(rms) - distance = np.linalg.norm(current_embedding - future_embedding) + return rms_values - displacement_per_tau[tau].append(distance) - mean_displacement_per_tau = { - tau: np.mean(displacements) - for tau, displacements in displacement_per_tau.items() - } - std_displacement_per_tau = { - tau: np.std(displacements) - for tau, displacements in displacement_per_tau.items() - } - - return mean_displacement_per_tau, std_displacement_per_tau +def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): + filtered_data = embedding_dataset.where( + (embedding_dataset["fov_name"] == fov_name) + & (embedding_dataset["track_id"] == track_id), + drop=True, + ) + features = filtered_data["features"].values # (sample, features) + time_points = filtered_data["t"].values # (sample,) + normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) + first_time_point_embedding = normalized_features[0].reshape(1, -1) + euclidean_distances = np.linalg.norm( + first_time_point_embedding - normalized_features, axis=1 + ) + return time_points, euclidean_distances.tolist() diff --git a/viscy/scripts/recompute_phate.py b/viscy/scripts/recompute_phate.py new file mode 100644 index 00000000..fccd5f11 --- /dev/null +++ b/viscy/scripts/recompute_phate.py @@ -0,0 +1,40 @@ +import logging +from pathlib import Path +from typing import Any + +from xarray import open_zarr + +from viscy.representation.evaluation.dimensionality_reduction import compute_phate + +_logger = logging.getLogger(__name__) + + +def update_phate_embeddings( + dataset_path: Path, + phate_kwargs: dict[str, Any], +) -> None: + """ + Update PHATE embeddings in an existing dataset with new parameters. + + Parameters + ---------- + dataset_path : Path + Path to the zarr store containing embeddings + phate_kwargs : dict + New PHATE parameters to use for recomputing embeddings. 
+ Common parameters include: + - n_components: int, number of dimensions (default: 2) + - knn: int, number of nearest neighbors (default: 5) + - decay: int, decay rate for kernel (default: 40) + - n_jobs: int, number of jobs for parallel processing + - t: int, number of diffusion steps + - gamma: float, gamma parameter for kernel + """ + # Load dataset + dataset = open_zarr(dataset_path, mode="r+") + + # Compute new PHATE embeddings and update dataset + _logger.info(f"Computing PHATE embeddings with parameters: {phate_kwargs}") + compute_phate(dataset, update_dataset=True, **phate_kwargs) + + _logger.info(f"Updated PHATE embeddings in {dataset_path}")
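A minimal usage sketch for the update_phate_embeddings helper above; the store path and parameter values are illustrative placeholders, and it assumes the zarr store was written by EmbeddingWriter so that PHATE1/PHATE2 coordinates already exist for update_dataset=True to overwrite.

import logging
from pathlib import Path

from viscy.scripts.recompute_phate import update_phate_embeddings

logging.basicConfig(level=logging.INFO)

# Placeholder path and parameters; point these at an existing embedding store.
update_phate_embeddings(
    dataset_path=Path("predictions/embeddings.zarr"),
    phate_kwargs={"n_components": 2, "knn": 10, "decay": 50, "n_jobs": -1},
)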