Skip to content

Commit

Permalink
Animated latent space (#137)
Browse files Browse the repository at this point in the history
* animated latent space

* delete duplicate umap calculation

---------

Co-authored-by: Shalin Mehta <[email protected]>
  • Loading branch information
ziw-liu and mattersoflight authored Aug 16, 2024
1 parent d8b2b2f commit faf26a2
Showing 1 changed file with 42 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# %%
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from umap import UMAP
import matplotlib.pyplot as plt

from viscy.light.embedding_writer import read_embedding_dataset
from viscy.data.triplet import TripletDataset, TripletDataModule
from iohub import open_ome_zarr
Expand All @@ -16,7 +17,7 @@
# %% Paths and parameters.

features_path = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04/tokenized-drop_path_0_0.zarr"
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04/tokenized-drop_path_0_0-2024-06-13.zarr"
)
data_path = Path(
"/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr"
Expand Down Expand Up @@ -152,10 +153,29 @@
# %%
# Compute UMAP over all features
features = embedding_dataset["features"]
# or select a well:
# features = features[features["fov_name"].str.contains("B/4")]

scaled_features = StandardScaler().fit_transform(features.values)
umap = UMAP()
# Fit UMAP on all features
embedding = umap.fit_transform(scaled_features)

# %%
# Add UMAP coordinates to the dataset

features = (
features.assign_coords(UMAP1=("sample", embedding[:, 0]))
.assign_coords(UMAP2=("sample", embedding[:, 1]))
.set_index(sample=["UMAP1", "UMAP2"], append=True)
)
features


sns.scatterplot(
x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
)

# %%
# Transform the track features
scaled_features_track_umap = umap.transform(scaled_features_track)
Expand All @@ -172,15 +192,9 @@
color="blue",
)
plt.show()
# %%
# load all unprojected features:
features = embedding_dataset["features"]
# or select a well:
# features = features[features["fov_name"].str.contains("B/4")]
features

# %%
# examine raw features
# examine random features
random_samples = np.random.randint(0, embedding_dataset.sizes["sample"], 700)
# concatenate fov_name, track_id, and t to create a unique sample identifier
sample_id = (
Expand All @@ -191,7 +205,7 @@
+ features["t"][random_samples].astype(str)
)
px.imshow(
features.values[random_samples],
scaled_features[random_samples],
labels={
"x": "feature",
"y": "sample",
Expand All @@ -201,24 +215,6 @@
# show fov_name as y-axis
)

# %%
scaled_features = StandardScaler().fit_transform(features.values)

umap = UMAP()

embedding = umap.fit_transform(scaled_features)
features = (
features.assign_coords(UMAP1=("sample", embedding[:, 0]))
.assign_coords(UMAP2=("sample", embedding[:, 1]))
.set_index(sample=["UMAP1", "UMAP2"], append=True)
)
features

# %%
sns.scatterplot(
x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
)


# %%
def load_annotation(da, path, name, categories: dict | None = None):
Expand Down Expand Up @@ -275,17 +271,28 @@ def load_annotation(da, path, name, categories: dict | None = None):

# %%
# interactive scatter plot to associate clusters with specific cells

px.scatter(
data_frame=pd.DataFrame(
{k: v for k, v in features.coords.items() if k != "features"}
),
df = pd.DataFrame({k: v for k, v in features.coords.items() if k != "features"})
df["infection"] = infection.values
df["division"] = division.values
df["well"] = df["fov_name"].str.rsplit("/", n=1).str[0]
df["fov_track_id"] = df["fov_name"] + "-" + df["track_id"].astype(str)
# select row B (DENV)
df = df[df["fov_name"].str.contains("B")]
df.sort_values("t", inplace=True)

g = px.scatter(
data_frame=df,
x="UMAP1",
y="UMAP2",
color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"),
symbol="infection",
color="well",
hover_name="fov_name",
hover_data=["id", "t"],
hover_data=["id", "t", "track_id"],
animation_frame="t",
animation_group="fov_track_id",
)
g.update_layout(width=800, height=600)


# %%
# cluster features in heatmap directly
Expand Down

0 comments on commit faf26a2

Please sign in to comment.