From 9689fa64f6d6ae18715ea52e646631d7f4ac4259 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 15 Aug 2024 21:06:27 -0700 Subject: [PATCH 1/2] animated latent space --- .../contrastive_cli/plot_embeddings.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py index 6a351ddf..c8f5e574 100644 --- a/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py +++ b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py @@ -1,13 +1,14 @@ # %% from pathlib import Path +import matplotlib.pyplot as plt import numpy as np import pandas as pd import plotly.express as px import seaborn as sns from sklearn.preprocessing import StandardScaler from umap import UMAP -import matplotlib.pyplot as plt + from viscy.light.embedding_writer import read_embedding_dataset from viscy.data.triplet import TripletDataset, TripletDataModule from iohub import open_ome_zarr @@ -16,7 +17,7 @@ # %% Paths and parameters. features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04/tokenized-drop_path_0_0.zarr" + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04/tokenized-drop_path_0_0-2024-06-13.zarr" ) data_path = Path( "/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr" @@ -275,17 +276,28 @@ def load_annotation(da, path, name, categories: dict | None = None): # %% # interactive scatter plot to associate clusters with specific cells - -px.scatter( - data_frame=pd.DataFrame( - {k: v for k, v in features.coords.items() if k != "features"} - ), +df = pd.DataFrame({k: v for k, v in features.coords.items() if k != "features"}) +df["infection"] = infection.values +df["division"] = division.values +df["well"] = df["fov_name"].str.rsplit("/", n=1).str[0] +df["fov_track_id"] = df["fov_name"] + "-" + df["track_id"].astype(str) +# select row B (DENV) +df = df[df["fov_name"].str.contains("B")] +df.sort_values("t", inplace=True) + +g = px.scatter( + data_frame=df, x="UMAP1", y="UMAP2", - color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"), + symbol="infection", + color="well", hover_name="fov_name", - hover_data=["id", "t"], + hover_data=["id", "t", "track_id"], + animation_frame="t", + animation_group="fov_track_id", ) +g.update_layout(width=800, height=600) + # %% # cluster features in heatmap directly From db2892f3f712f650bc349e84664801e9f58d7567 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 15 Aug 2024 21:36:52 -0700 Subject: [PATCH 2/2] delete duplicate umap calculation --- .../contrastive_cli/plot_embeddings.py | 47 +++++++++---------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py index c8f5e574..63ce73d4 100644 --- a/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py +++ b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py @@ -153,10 +153,29 @@ # %% # Compute UMAP over all features features = embedding_dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] + scaled_features = StandardScaler().fit_transform(features.values) umap = UMAP() # Fit UMAP on all features embedding = umap.fit_transform(scaled_features) + +# %% +# Add UMAP coordinates to the dataset + +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + + +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + # %% # Transform the track features scaled_features_track_umap = umap.transform(scaled_features_track) @@ -173,15 +192,9 @@ color="blue", ) plt.show() -# %% -# load all unprojected features: -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] -features # %% -# examine raw features +# examine random features random_samples = np.random.randint(0, embedding_dataset.sizes["sample"], 700) # concatenate fov_name, track_id, and t to create a unique sample identifier sample_id = ( @@ -192,7 +205,7 @@ + features["t"][random_samples].astype(str) ) px.imshow( - features.values[random_samples], + scaled_features[random_samples], labels={ "x": "feature", "y": "sample", @@ -202,24 +215,6 @@ # show fov_name as y-axis ) -# %% -scaled_features = StandardScaler().fit_transform(features.values) - -umap = UMAP() - -embedding = umap.fit_transform(scaled_features) -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - -# %% -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - # %% def load_annotation(da, path, name, categories: dict | None = None):