diff --git a/applications/contrastive_phenotyping/evaluation/PCA_UMAP_evaluation.py b/applications/contrastive_phenotyping/evaluation/PCA_UMAP_evaluation.py
new file mode 100644
index 00000000..2933832f
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/PCA_UMAP_evaluation.py
@@ -0,0 +1,181 @@
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from umap import UMAP
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import load_annotation
+
+# %% Paths and parameters.
+
+features_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
+)
+data_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+)
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+)

+# %%
+embedding_dataset = read_embedding_dataset(features_path)
+embedding_dataset
+
+# %%
+# Compute UMAP over all features
+features = embedding_dataset["features"]
+# or select a well:
+# features = features[features["fov_name"].str.contains("B/4")]
+
+scaled_features = StandardScaler().fit_transform(features.values)
+umap = UMAP()
+# Fit UMAP on all features
+embedding = umap.fit_transform(scaled_features)
+
+# %%
+# Add UMAP coordinates to the dataset and plot with time
+
+features = (
+    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
+    .assign_coords(UMAP2=("sample", embedding[:, 1]))
+    .set_index(sample=["UMAP1", "UMAP2"], append=True)
+)
+features
+
+
+sns.scatterplot(
+    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
+)
+
+# Add the title to the plot
+plt.title("Cell & Time Aware (30 min interval)")
+plt.savefig('umap_cell_time_aware_time.svg', format='svg')
+plt.savefig('umap_cell_time_aware_time.pdf', format='pdf')
+# Show the plot
+plt.show()
+
+# %%
+
+any_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr")
+embedding_dataset = read_embedding_dataset(any_features_path)
+embedding_dataset
+
+# %%
+# Compute UMAP over all features
+features = embedding_dataset["features"]
+# or select a well:
+# features = features[features["fov_name"].str.contains("B/4")]
+
+scaled_features = StandardScaler().fit_transform(features.values)
+umap = UMAP()
+# Fit UMAP on all features
+embedding = umap.fit_transform(scaled_features)
+
+# %% Any time sampling plot
+
+features = (
+    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
+    .assign_coords(UMAP2=("sample", embedding[:, 1]))
+    .set_index(sample=["UMAP1", "UMAP2"], append=True)
+)
+features
+
+
+sns.scatterplot(
+    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
+)
+
+# Add the title to the plot
+plt.title("Cell Aware Any Time")
+plt.savefig('umap_cell_aware_time.png', format='png')
+#plt.savefig('umap_cell_aware_time.pdf', format='pdf')
+# Show the plot
+plt.show()
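+# %% (editor's sketch) The scale -> UMAP -> assign-coords block above is repeated
+# verbatim for each model in this script; a small helper like this could replace
+# the copies. The name `fit_umap_features` is hypothetical, not part of viscy.
+def fit_umap_features(features):
+    scaled = StandardScaler().fit_transform(features.values)
+    embedding = UMAP().fit_transform(scaled)
+    return (
+        features.assign_coords(UMAP1=("sample", embedding[:, 0]))
+        .assign_coords(UMAP2=("sample", embedding[:, 1]))
+        .set_index(sample=["UMAP1", "UMAP2"], append=True)
+    )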
+
+# %%
+
+contrastive_learning_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr")
+embedding_dataset = read_embedding_dataset(contrastive_learning_path)
+embedding_dataset
+
+# %%
+# Compute UMAP over all features
+features = embedding_dataset["features"]
+# or select a well:
+# features = features[features["fov_name"].str.contains("B/4")]
+
+scaled_features = StandardScaler().fit_transform(features.values)
+umap = UMAP()
+# Fit UMAP on all features
+embedding = umap.fit_transform(scaled_features)
+
+# %% Classical contrastive sampling plot
+
+features = (
+    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
+    .assign_coords(UMAP2=("sample", embedding[:, 1]))
+    .set_index(sample=["UMAP1", "UMAP2"], append=True)
+)
+features
+
+
+sns.scatterplot(
+    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
+)
+
+# Add the title to the plot
+plt.title("Classical Contrastive Learning")
+plt.savefig('classical_time.svg', format='svg')
+plt.savefig('classical_time.pdf', format='pdf')
+
+# Show the plot
+plt.show()
+
+# %% PCA
+
+pca = PCA(n_components=4)
+# scaled_features = StandardScaler().fit_transform(features.values)
+# pca_features = pca.fit_transform(scaled_features)
+pca_features = pca.fit_transform(features.values)
+
+features = (
+    features.assign_coords(PCA1=("sample", pca_features[:, 0]))
+    .assign_coords(PCA2=("sample", pca_features[:, 1]))
+    .assign_coords(PCA3=("sample", pca_features[:, 2]))
+    .assign_coords(PCA4=("sample", pca_features[:, 3]))
+    .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True)
+)
+
+# %% plot PCA components with time
+
+plt.figure(figsize=(10, 10))
+sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8)
+
+# %% OVERLAY INFECTION ANNOTATION
+ann_root = Path(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred"
+)
+
+infection = load_annotation(
+    features,
+    ann_root / "extracted_inf_state.csv",
+    "infection_state",
+    {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
+)
+
+# %%
+sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)
+
+# %% plot PCA components with infection hue
+sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8)
+
+# %%
diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py
new file mode 100644
index 00000000..325f11c7
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/cosine_similarity.py
@@ -0,0 +1,253 @@
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from umap import UMAP
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import dataset_of_tracks, load_annotation
+from sklearn.metrics.pairwise import cosine_similarity
+from collections import defaultdict
+from viscy.representation.evaluation import calculate_cosine_similarity_cell
+from viscy.representation.evaluation import compute_displacement_mean_std
+from viscy.representation.evaluation import compute_displacement
+
+# %% Paths
and parameters. + +features_path_30_min = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" +) + +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + +features_path_any_time = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") + +data_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) + +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + +# %% Load embedding datasets for all three sampling +fov_name = '/B/4/6' +track_id = 4 + +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) +embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) + +# Calculate cosine similarities for each sampling +time_points_30_min, cosine_similarities_30_min = calculate_cosine_similarity_cell(embedding_dataset_30_min, fov_name, track_id) +time_points_no_track, cosine_similarities_no_track = calculate_cosine_similarity_cell(embedding_dataset_no_track, fov_name, track_id) +time_points_any_time, cosine_similarities_any_time = calculate_cosine_similarity_cell(embedding_dataset_any_time, fov_name, track_id) + +# %% Plot cosine similarities over time for all three conditions + +plt.figure(figsize=(10, 6)) + +plt.plot(time_points_no_track, cosine_similarities_no_track, marker='o', label='classical contrastive (no tracking)') +plt.plot(time_points_any_time, cosine_similarities_any_time, marker='o', label='cell aware') +plt.plot(time_points_30_min, cosine_similarities_30_min, marker='o', label='cell & time aware (interval 30 min)') + + +plt.xlabel("Time Delay (t)") +plt.ylabel("Cosine Similarity with First Time Point") +plt.title("Cosine Similarity Over Time for Infected Cell") + +#plt.savefig('infected_cell_example.svg', format='svg') +#plt.savefig('infected_cell_example.pdf', format='pdf') + +plt.grid(True) +plt.legend() + + +plt.show() +# %% + +# %% import statements + +from pathlib import Path +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics.pairwise import euclidean_distances + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks, load_annotation +from sklearn.metrics.pairwise import cosine_similarity + +# %% Paths to datasets +features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") +#features_path_any_time = 
Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_1chan_128patch_32projDim/1chan_128patch_63ckpt_FebTest.zarr") + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) +#embedding_dataset_any_time = read_embedding_dataset(features_path_any_time) + +# %% Compute displacements for both datasets (using Euclidean distance and Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + +mean_displacement_30_min, std_displacement_30_min = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=False, use_dissimilarity=True) +mean_displacement_no_track, std_displacement_no_track = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=False, use_dissimilarity=True) +#mean_displacement_any_time, std_displacement_any_time = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=False) + +mean_displacement_30_min_cosine, std_displacement_30_min_cosine = compute_displacement_mean_std(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) +mean_displacement_no_track_cosine, std_displacement_no_track_cosine = compute_displacement_mean_std(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) +#mean_displacement_any_time_cosine, std_displacement_any_time_cosine = compute_displacement_mean_std(embedding_dataset_any_time, max_tau, use_cosine=True) +# %% Plot 1: Euclidean Displacements +plt.figure(figsize=(10, 6)) + +taus = list(mean_displacement_30_min.keys()) +mean_values_30_min = list(mean_displacement_30_min.values()) +std_values_30_min = list(std_displacement_30_min.values()) + +mean_values_no_track = list(mean_displacement_no_track.values()) +std_values_no_track = list(std_displacement_no_track.values()) + +# mean_values_any_time = list(mean_displacement_any_time.values()) +# std_values_any_time = list(std_displacement_any_time.values()) + +# Plotting Euclidean displacements +plt.plot(taus, mean_values_30_min, marker='o', label='Cell & Time Aware (30 min interval)') +plt.fill_between(taus, np.array(mean_values_30_min) - np.array(std_values_30_min), np.array(mean_values_30_min) + np.array(std_values_30_min), + color='gray', alpha=0.3, label='Std Dev (30 min interval)') + +plt.plot(taus, mean_values_no_track, marker='o', label='Classical Contrastive (No Tracking)') +plt.fill_between(taus, np.array(mean_values_no_track) - np.array(std_values_no_track), np.array(mean_values_no_track) + np.array(std_values_no_track), + color='blue', alpha=0.3, label='Std Dev (No Tracking)') + +plt.xlabel('Time Shift (τ)') +plt.ylabel('Displacement') +plt.title('Embedding Displacement Over Time') +plt.grid(True) +plt.legend() + +# plt.savefig('embedding_displacement_euclidean.svg', format='svg') +# plt.savefig('embedding_displacement_euclidean.pdf', format='pdf') + +# Show the Euclidean plot +plt.show() + +# %% Plot 2: Cosine Displacements +plt.figure(figsize=(10, 6)) + +# Plotting Cosine displacements +mean_values_30_min_cosine = list(mean_displacement_30_min_cosine.values()) +std_values_30_min_cosine = list(std_displacement_30_min_cosine.values()) + +mean_values_no_track_cosine = list(mean_displacement_no_track_cosine.values()) +std_values_no_track_cosine = list(std_displacement_no_track_cosine.values()) + 
+plt.plot(taus, mean_values_30_min_cosine, marker='o', label='Cell & Time Aware (30 min interval)') +plt.fill_between(taus, np.array(mean_values_30_min_cosine) - np.array(std_values_30_min_cosine), np.array(mean_values_30_min_cosine) + np.array(std_values_30_min_cosine), + color='orange', alpha=0.3, label='Std Dev (30 min interval, Cosine)') + +plt.plot(taus, mean_values_no_track_cosine, marker='o', label='Classical Contrastive (No Tracking)') +plt.fill_between(taus, np.array(mean_values_no_track_cosine) - np.array(std_values_no_track_cosine), np.array(mean_values_no_track_cosine) + np.array(std_values_no_track_cosine), + color='red', alpha=0.3, label='Std Dev (No Tracking)') + +plt.xlabel('Time Shift (τ)') +plt.ylabel('Cosine Similarity') +plt.title('Embedding Displacement Over Time') +plt.grid(True) +plt.legend() + +# Show the Cosine plot +plt.show() +# %% + +import seaborn as sns +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from collections import defaultdict +from sklearn.metrics.pairwise import cosine_similarity +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% Paths to datasets +features_path_30_min = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr") +feature_path_no_track = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr") + +# %% Read embedding datasets +embedding_dataset_30_min = read_embedding_dataset(features_path_30_min) +embedding_dataset_no_track = read_embedding_dataset(feature_path_no_track) + +# %% Compute displacements for both datasets (using Cosine similarity) +max_tau = 10 # Maximum time shift (tau) to compute displacements + +# Compute displacements for Cell & Time Aware (30 min interval) using Cosine similarity +displacement_per_tau_aware_cosine = compute_displacement(embedding_dataset_30_min, max_tau, use_cosine=True, use_dissimilarity=True) + +# Compute displacements for Classical Contrastive (No Tracking) using Cosine similarity +displacement_per_tau_contrastive_cosine = compute_displacement(embedding_dataset_no_track, max_tau, use_cosine=True, use_dissimilarity=True) + +# %% Prepare data for violin plot +# Prepare the data in a long-form DataFrame for the violin plot +def prepare_violin_data(taus, displacement_aware, displacement_contrastive): + # Create a list to hold the data + data = [] + + # Populate the data for Cell & Time Aware + for tau in taus: + displacements_aware = displacement_aware.get(tau, []) + for displacement in displacements_aware: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Cell & Time Aware (30 min interval)'}) + + # Populate the data for Classical Contrastive + for tau in taus: + displacements_contrastive = displacement_contrastive.get(tau, []) + for displacement in displacements_contrastive: + data.append({'Time Shift (τ)': tau, 'Displacement': displacement, 'Sampling': 'Classical Contrastive (No Tracking)'}) + + # Convert to a DataFrame + df = pd.DataFrame(data) + return df + +# Assuming 'displacement_per_tau_aware_cosine' and 'displacement_per_tau_contrastive_cosine' hold the displacements as dictionaries +taus = list(displacement_per_tau_aware_cosine.keys()) + +# Prepare the violin plot data +df = prepare_violin_data(taus, displacement_per_tau_aware_cosine, 
displacement_per_tau_contrastive_cosine) + +# Create a violin plot using seaborn +plt.figure(figsize=(12, 8)) +sns.violinplot( + x='Time Shift (τ)', + y='Displacement', + hue='Sampling', + data=df, + palette='Set2', + scale='width', + bw=0.2, + inner=None, + split=True, + cut=0 +) + +# Add labels and title +plt.xlabel('Time Shift (τ)', fontsize=14) +plt.ylabel('Cosine Dissimilarity', fontsize=14) +plt.title('Cosine Dissimilarity Distribution', fontsize=16) +plt.grid(True, linestyle='--', alpha=0.6) # Lighter grid lines for less distraction +plt.legend(title='Sampling', fontsize=12, title_fontsize=14) + +plt.ylim(0, 0.5) + +# Save the violin plot as SVG and PDF +plt.savefig('violin_plot_cosine_similarity.svg', format='svg') +plt.savefig('violin_plot_cosine_similarity.pdf', format='pdf') + +# Show the plot +plt.show() +# %% diff --git a/applications/contrastive_phenotyping/figures/figure_4a_1.py b/applications/contrastive_phenotyping/figures/figure_4a_1.py index a670db0d..273839ad 100644 --- a/applications/contrastive_phenotyping/figures/figure_4a_1.py +++ b/applications/contrastive_phenotyping/figures/figure_4a_1.py @@ -6,13 +6,13 @@ import seaborn as sns from sklearn.preprocessing import StandardScaler from umap import UMAP - from viscy.representation.embedding_writer import read_embedding_dataset +from sklearn.decomposition import PCA # %% Defining Paths for February and June Datasets -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_data_path = Path("/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr") -feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr") +feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/febtest_predict.zarr") +feb_data_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr") +feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr") # %% Function to Load and Process the Embedding Dataset def compute_umap(embedding_dataset): @@ -31,7 +31,8 @@ def compute_umap(embedding_dataset): # %% Function to Load Annotations def load_annotation(da, path, name, categories: dict | None = None): annotation = pd.read_csv(path) - annotation["fov_name"] = "/" + annotation["fov ID"] + print(annotation.columns) + annotation["fov_name"] = "/" + annotation["fov name "] annotation = annotation.set_index(["fov_name", "id"]) mi = pd.MultiIndex.from_arrays( [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] @@ -52,116 +53,99 @@ def plot_umap_infection(features, infection, title): feb_embedding_dataset = read_embedding_dataset(feb_features_path) feb_features = compute_umap(feb_embedding_dataset) -feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track") -feb_infection = load_annotation(feb_features, feb_ann_root / "tracking_v1_infection.csv", "infection class", {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) +feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred") 
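+# (note) The annotation CSV is assumed to have columns "fov name " (with a
+# trailing space), "id", and "infection_state", matching the load_annotation
+# call below; uncomment to verify:
+# print(pd.read_csv(feb_ann_root / "extracted_inf_state.csv").columns.tolist())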
+feb_infection = load_annotation(feb_features, feb_ann_root / "extracted_inf_state.csv", "infection_state", {0.0: "background", 1.0: "uninfected", 2.0: "infected"})

 # %% Plot UMAP with Infection Status for February Dataset
 plot_umap_infection(feb_features, feb_infection, "February Dataset")

-# %%
-print(feb_embedding_dataset)
-print(feb_infection)
-print(feb_features)
-# %%
-
-
-# %% Identify cells by infection type using fov_name
-mock_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/3') | feb_features['fov_name'].str.contains('/B/3'))
-zika_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/4'))
-dengue_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/B/4'))
+# %% Function to Load and Process the Embedding Dataset using PCA
+import xarray as xr  # needed for the DataArray constructed below
+
+def compute_pc(embedding_dataset):
+    features = embedding_dataset["features"]
+
+    # Standardize the features
+    scaled_features = StandardScaler().fit_transform(features.values)
+
+    # Perform PCA
+    pca = PCA(n_components=2)
+    pc_embedding = pca.fit_transform(scaled_features)
+
+    # Convert the PCA embedding into an xarray.DataArray, maintaining the same coordinates as the original dataset
+    pc_data = xr.DataArray(
+        data=pc_embedding,
+        dims=["sample", "pc"],
+        coords={
+            "sample": features.coords["sample"],
+            "fov_name": features.coords["fov_name"],
+            "track_id": features.coords["track_id"],
+            "t": features.coords["t"],
+            "id": features.coords["id"],
+            "PC1": ("sample", pc_embedding[:, 0]),
+            "PC2": ("sample", pc_embedding[:, 1])
+        }
+    )
+
+    return pc_data

-# %% Plot UMAP with Infection Status
-plt.figure(figsize=(10, 8))
-sns.scatterplot(x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8)

-# Overlay with circled cells
-plt.scatter(mock_cells["UMAP1"], mock_cells["UMAP2"], facecolors='none', edgecolors='blue', s=20, label='Mock Cells')
-plt.scatter(zika_cells["UMAP1"], zika_cells["UMAP2"], facecolors='none', edgecolors='green', s=20, label='Zika MOI 5')
-plt.scatter(dengue_cells["UMAP1"], dengue_cells["UMAP2"], facecolors='none', edgecolors='red', s=20, label='Dengue MOI 5')
+# %% Function to Plot PCA with Infection Status
+def plot_pca_infection(pca_data, infection, title):
+    plt.figure(figsize=(10, 8))
+    sns.scatterplot(x=pca_data["PC1"], y=pca_data["PC2"], hue=infection, s=7, alpha=0.8)
+    plt.title(f"PCA Plot - {title}")
+    plt.xlabel("PC1")
+    plt.ylabel("PC2")
+    plt.show()

-# Add legend and show plot
-plt.legend(loc='best')
-plt.title("UMAP Plot - February Dataset with Mock, Zika, and Dengue Highlighted")
-plt.show()
+# %% Plot PCA for February Dataset
+feb_pca_data = compute_pc(feb_embedding_dataset)
+print(feb_pca_data)
+#plot_pca_infection(feb_pca_data, feb_infection, "February Dataset")
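+# %% (sketch) How much variance do the two PCs above capture?
+# Quick check using the same standardization as compute_pc; `_pca` is a
+# throwaway local, not part of the figure pipeline.
+_scaled = StandardScaler().fit_transform(feb_embedding_dataset["features"].values)
+_pca = PCA(n_components=2).fit(_scaled)
+print("Explained variance ratio:", _pca.explained_variance_ratio_)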
 # %%
-# %% Create a 1x3 grid of heatmaps
-fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)
-
-# Mock Cells Heatmap
-sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0])
-axs[0].set_title('Mock Cells')
-axs[0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Zika Cells Heatmap
-sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1])
-axs[1].set_title('Zika MOI 5')
-axs[1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Dengue Cells Heatmap
-sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[2])
-axs[2].set_title('Dengue MOI 5')
-axs[2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Set labels and adjust layout
-for ax in axs:
-    ax.set_xlabel('UMAP1')
-    ax.set_ylabel('UMAP2')
-
-plt.tight_layout()
-plt.show()
-
+print("PCA embedding shape: ", feb_pca_data.shape)
+print("PCA embedding head: ", feb_pca_data.values[:5])
 # %%
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-# %% Create a 2x3 grid of heatmaps (1 row for each heatmap, splitting infected and uninfected in the second row)
-fig, axs = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)
-
-# Mock Cells Heatmap
-sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0, 0])
-axs[0, 0].set_title('Mock Cells')
-axs[0, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Zika Cells Heatmap
-sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[0, 1])
-axs[0, 1].set_title('Zika MOI 5')
-axs[0, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Dengue Cells Heatmap
-sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[0, 2])
-axs[0, 2].set_title('Dengue MOI 5')
-axs[0, 2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
+def plot_umap_histogram(umap_data, infection, title):
+    plt.figure(figsize=(15, 5))
+    states = infection.unique()  # Get unique infection states
+
+    for i, state in enumerate(states):
+        plt.subplot(1, len(states), i + 1)
+
+        # `infection` is built from the dataset's own (fov_name, id) index, so
+        # it is aligned sample-by-sample with `umap_data`; a boolean mask
+        # selects each state directly.
+        mask = (infection == state).values
+
+        # Create the hexbin plot for each state
+        plt.hexbin(umap_data["UMAP1"].values[mask], umap_data["UMAP2"].values[mask], gridsize=50, cmap="inferno")
+        plt.title(f"Infection = {state}")
+        plt.xlabel("UMAP1")
+        plt.ylabel("UMAP2")
+
+    # Set the main title
+    plt.suptitle(f"{title} - UMAP Histogram")
+    plt.show()

-# Infected Cells Heatmap
-sns.histplot(x=infected_cells["UMAP1"], y=infected_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[1, 0])
-axs[1, 0].set_title('Infected Cells')
-axs[1, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Uninfected Cells Heatmap
-sns.histplot(x=uninfected_cells["UMAP1"], y=uninfected_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1, 1])
-axs[1, 1].set_title('Uninfected Cells')
-axs[1, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Remove the last subplot (bottom right corner)
-fig.delaxes(axs[1, 2])
-
-# Set labels and adjust layout
-for ax in axs.flat:
-    ax.set_xlabel('UMAP1')
-    ax.set_ylabel('UMAP2')
+
+# %%
+feb_umap_data = compute_umap(feb_embedding_dataset)
-
-plt.tight_layout()
-plt.show()
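+# (note) plot_umap_histogram assumes `feb_infection` is ordered like
+# `feb_umap_data`; load_annotation indexes by the dataset's own (fov_name, id),
+# so the two should at least match in length:
+assert len(feb_infection) == feb_umap_data.sizes["sample"]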
+plot_umap_histogram(feb_umap_data, feb_infection, "February Dataset") +# %% +print(feb_umap_data) +print(feb_infection) +plot_umap_histogram(feb_umap_data, feb_infection, "February Dataset") +# %% +print(feb_umap_data.coords["fov_name"]) # %% diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py index d3052018..fb2d5a73 100644 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py +++ b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py @@ -3,10 +3,9 @@ import matplotlib.pyplot as plt import pandas as pd - +from pathlib import Path from viscy.representation.embedding_writer import read_embedding_dataset - # %% Function to Load Annotations from GMM CSV def load_gmm_annotation(gmm_csv_path): gmm_df = pd.read_csv(gmm_csv_path) @@ -25,23 +24,24 @@ def count_infected_cell_states_over_time(embedding_dataset, gmm_df): # Merge with GMM data to add GMM labels df = pd.merge(df, gmm_df[['id', 'fov_name', 'Predicted_Label']], on=['fov_name', 'id'], how='left') - # Filter by time range (3 HPI to 30 HPI) + # Filter by time range (3 HPI to 27 HPI) df = df[(df['t'] >= 3) & (df['t'] <= 27)] - # Determine the well type (Mock, Zika, Dengue) based on fov_name - df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else - ('Zika' if '/A/4' in x else 'Dengue')) + # Determine the well type (Mock and Dengue) based on fov_name + df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else 'Dengue') - # Group by time, well type, and GMM label to count the number of infected cells + # Group by time, well type, and GMM label to count the number of infected and total cells state_counts = df.groupby(['t', 'well_type', 'Predicted_Label']).size().unstack(fill_value=0) # Ensure that 'infected' column exists if 'infected' not in state_counts.columns: state_counts['infected'] = 0 - # Calculate the percentage of infected cells + # Calculate the total number of cells (including uninfected) in each well type state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 + + # Calculate the percentage of infected cells + state_counts['infected_percentage'] = (state_counts['infected'] / state_counts['total']) * 100 return state_counts @@ -49,17 +49,17 @@ def count_infected_cell_states_over_time(embedding_dataset, gmm_df): def plot_infected_cell_states(state_counts): plt.figure(figsize=(12, 8)) - # Loop through each well type - for well_type in ['Mock', 'Zika', 'Dengue']: + # Loop through each well type (Mock, Dengue) + for well_type in ['Mock', 'Dengue']: # Select the data for the current well type if well_type in state_counts.index.get_level_values('well_type'): well_data = state_counts.xs(well_type, level='well_type') # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') + if 'infected_percentage' in well_data.columns: + plt.plot(well_data.index, well_data['infected_percentage'], label=f'{well_type} - Infected') - plt.title("Percentage of Infected Cells Over Time - February") + plt.title("Percentage of Infected Cells Over Time - Mock and Dengue Wells") plt.xlabel("Hours Post Perturbation") plt.ylabel("Percentage of Infected Cells") plt.legend(title="Well Type") @@ -67,11 +67,11 @@ def plot_infected_cell_states(state_counts): plt.show() # %% Load and process Feb Dataset 
-feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr")
+feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/febtest_predict.zarr")
 feb_embedding_dataset = read_embedding_dataset(feb_features_path)
 
 # Load the GMM annotation CSV
-gmm_csv_path = "june_logistic_regression_predicted_labels_feb_pca.csv"  # Path to CSV file
+gmm_csv_path = "feb_test_regression_predicted_labels_embedding.csv"  # Path to CSV file
 gmm_df = load_gmm_annotation(gmm_csv_path)
 
 # %% Count Infected Cell States Over Time as Percentage using GMM labels
@@ -84,4 +84,3 @@ def plot_infected_cell_states(state_counts):
 
 # %%
-
diff --git a/viscy/representation/evaluation.py b/viscy/representation/evaluation.py
index cbf8ead0..d1e3830b 100644
--- a/viscy/representation/evaluation.py
+++ b/viscy/representation/evaluation.py
@@ -15,8 +15,9 @@
 )
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.preprocessing import StandardScaler
-
+from sklearn.metrics.pairwise import cosine_similarity
 from viscy.data.triplet import TripletDataModule
+from collections import defaultdict
 
 """
 This module enables evaluation of learned representations using annotations, such as
@@ -379,3 +380,223 @@ def compute_std_dev(image):
     std_dev = np.std(image)
 
     return std_dev
+
+
+# Function to extract embeddings and calculate cosine similarities for a specific cell
+def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
+    # Filter the dataset for the specific infected cell
+    filtered_data = embedding_dataset.where(
+        (embedding_dataset['fov_name'] == fov_name) &
+        (embedding_dataset['track_id'] == track_id),
+        drop=True
+    )
+
+    # Extract the feature embeddings and time points
+    features = filtered_data['features'].values  # (sample, features)
+    time_points = filtered_data['t'].values  # (sample,)
+
+    # Get the first time point's embedding
+    first_time_point_embedding = features[0].reshape(1, -1)
+
+    # Calculate cosine similarity between each time point and the first time point
+    cosine_similarities = []
+    for i in range(len(time_points)):
+        similarity = cosine_similarity(
+            first_time_point_embedding, features[i].reshape(1, -1)
+        )
+        cosine_similarities.append(similarity[0][0])
+
+    return time_points, cosine_similarities
+
+
+# Function to collect displacements between embeddings at t and t + tau
+def compute_displacement(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False):
+    # Get the arrays of (fov_name, track_id, t, and embeddings)
+    fov_names = embedding_dataset['fov_name'].values
+    track_ids = embedding_dataset['track_id'].values
+    timepoints = embedding_dataset['t'].values
+    embeddings = embedding_dataset['features'].values
+
+    # Dictionary to store displacements for each tau
+    displacement_per_tau = defaultdict(list)
+
+    # Iterate over all entries in the dataset
+    for i in range(len(fov_names)):
+        fov_name = fov_names[i]
+        track_id = track_ids[i]
+        current_time = timepoints[i]
+        current_embedding = embeddings[i]
+
+        # For each time point t, compute displacements for t + tau
+        for tau in range(1, max_tau + 1):
+            future_time = current_time + tau
+
+            # Find if future_time exists for the same (fov_name, track_id)
+            matching_indices = np.where(
+                (fov_names == fov_name) & (track_ids == track_id) & (timepoints == future_time)
+            )[0]
+
+            if len(matching_indices) == 1:
+                # Get the embedding at t + tau
+                future_embedding = embeddings[matching_indices[0]]
+
+                if use_cosine:
+                    # Compute cosine similarity
+                    similarity = cosine_similarity(current_embedding.reshape(1, -1), future_embedding.reshape(1, -1))[0][0]
+                    # Report dissimilarity (1 - similarity) or raw similarity
+                    displacement = 1 - similarity if use_dissimilarity else similarity
+                else:
+                    # Squared Euclidean distance between the two embeddings
+                    displacement = np.sum((current_embedding - future_embedding) ** 2)
+
+                # Store the displacement for the given tau
+                displacement_per_tau[tau].append(displacement)
+
+    return displacement_per_tau
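+
+# (usage sketch) Both displacement helpers expect an embedding dataset with
+# fov_name/track_id/t coordinates, e.g.:
+#   ds = read_embedding_dataset("predictions.zarr")  # hypothetical path
+#   per_tau = compute_displacement(ds, max_tau=10, use_cosine=True, use_dissimilarity=True)
+#   mean_d, std_d = compute_displacement_mean_std(ds, max_tau=10, use_cosine=True, use_dissimilarity=True)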
+
+# Function to summarize displacements as mean and std for each tau
+def compute_displacement_mean_std(embedding_dataset, max_tau=10, use_cosine=False, use_dissimilarity=False):
+    # Delegate to compute_displacement rather than duplicating the matching loop
+    displacement_per_tau = compute_displacement(embedding_dataset, max_tau, use_cosine, use_dissimilarity)
+
+    # Compute mean and std displacement for each tau by averaging the displacements
+    mean_displacement_per_tau = {tau: np.mean(displacements) for tau, displacements in displacement_per_tau.items()}
+    std_displacement_per_tau = {tau: np.std(displacements) for tau, displacements in displacement_per_tau.items()}
+
+    return mean_displacement_per_tau, std_displacement_per_tau
diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py
index 2f88a806..774ae5fa 100644
--- a/viscy/translation/engine.py
+++ b/viscy/translation/engine.py
@@ -68,7 +68,7 @@ def __init__(
         self.l2_alpha = l2_alpha
         self.ms_dssim_alpha = ms_dssim_alpha
 
-    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
     def forward(self, preds, target):
         loss = 0
         if
self.l1_alpha: