adding the feature maps PCA to RGB.
edyoshikun committed Aug 24, 2024
1 parent 67a23bb commit 2ff67f8
Showing 1 changed file with 177 additions and 4 deletions: part_1/solution.py
@@ -1470,15 +1470,188 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
plt.tight_layout()
plt.show()



# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# <h3> Task 3.1: Let's look at what the model is learning </h3>
#
# - Here we will visualize the encoder and decoder feature maps of the trained model. <br>
# - We will use PCA to reduce each feature map and render its first 3 principal components as RGB. <br>
# </div>

#%%
"""
Script to visualize the encoder feature maps of a trained model.
Using PCA to visualize feature maps is inspired by
https://doi.org/10.48550/arXiv.2304.07193 (Oquab et al., 2023).
"""
from matplotlib.patches import Rectangle
from skimage.exposure import rescale_intensity
from skimage.transform import downscale_local_mean
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from typing import NamedTuple

def feature_map_pca(feature_map: np.ndarray, n_components: int = 8) -> PCA:
    """
    Compute PCA on a feature map.

    :param np.ndarray feature_map: (C, H, W) feature map
    :param int n_components: number of principal components to keep
    :return: fitted sklearn PCA object
    """
    # Flatten the spatial dimensions: (C, H, W) -> (C, H*W), treating each
    # channel as one sample with H*W pixel features.
    feat = feature_map.reshape(feature_map.shape[0], -1)
    pca = PCA(n_components=n_components)
    pca.fit(feat)
    return pca

def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:
    """Map the first 3 principal components of a (B, C, H, W) feature map to RGB."""
    pca = feature_map_pca(feat[0], n_components=n_components)
    # Each component holds H*W loadings; reshape them back into spatial maps.
    pc_first_3 = pca.components_[:3].reshape(3, *feat.shape[-2:])
    # Rescale each component to [0, 1] and stack as the R, G, B channels.
    return np.stack(
        [rescale_intensity(pc, out_range=(0, 1)) for pc in pc_first_3], axis=-1
    )
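# %%
# Quick sanity check of the PCA helpers on synthetic data (hypothetical
# shapes): a (1, C, H, W) feature map should map to an (H, W, 3) RGB image
# with values in [0, 1].
dummy_feat = np.random.rand(1, 16, 48, 48).astype(np.float32)
dummy_rgb = pcs_to_rgb(dummy_feat, n_components=3)
print(dummy_rgb.shape, dummy_rgb.min(), dummy_rgb.max())  # (48, 48, 3), 0.0, 1.0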
#%%
# Load the test dataset
test_data_path = top_dir / "06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr"
test_dataset = open_ome_zarr(test_data_path)

# Looking at the test dataset
print('Test dataset:')
test_dataset.print_tree()

# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# - Change the `fov` and `crop` size below to visualize the encoder and decoder feature maps for different fields of view. <br>
# - Note: `crop` should be a multiple of 384 (checked by the assertion below). <br>
# </div>
#%%
# Load one position
row = 0
col = 0
center_index = 2
n = 1
crop = 384 * n
fov = 10
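
# The note above requires crops in multiples of 384; make that explicit so a
# bad value fails here rather than inside the model.
assert crop % 384 == 0, f"crop={crop} must be a multiple of 384"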

# Fetch the normalization statistics for the phase channel
norm_meta = test_dataset.zattrs["normalization"]["Phase3D"]["dataset_statistics"]

# Get the image shape and channel index from the OME-Zarr metadata
Y, X = test_dataset[f"0/0/{fov}"].data.shape[-2:]
print(test_dataset.channel_names)
phase_idx = test_dataset.channel_names.index('Phase3D')
assert crop <= Y and crop <= X, "Crop size larger than the image. Check the image shape."

# Center crop of the phase channel (keeping singleton C and Z axes) and the fluorescence channels
y0, x0 = Y // 2 - crop // 2, X // 2 - crop // 2
phase_img = test_dataset[f"0/0/{fov}/0"][
    :, phase_idx : phase_idx + 1, 0:1, y0 : y0 + crop, x0 : x0 + crop
]
fluo = test_dataset[f"0/0/{fov}/0"][0, 1:3, 0, y0 : y0 + crop, x0 : x0 + crop]

phase_img = (phase_img - norm_meta["median"]) / norm_meta["iqr"]
plt.imshow(phase_img[0,0,0], cmap="gray")
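# Label the preview of the normalized phase crop (display only)
plt.title(f"Normalized phase, FOV {fov}, {crop}x{crop} crop")
plt.axis("off")
plt.show()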

# %%
# TODO: modify the path to your model checkpoint
# phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
# ))[-1]

# TODO: rerun with pretrained
pretrained_model_ckpt = (
    top_dir
    / "06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
)

# load model
model = VSUNet.load_from_checkpoint(
    pretrained_model_ckpt,
    architecture="UNeXt2_2D",
    model_config=phase2fluor_config.copy(),
    accelerator="gpu",
)
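
# %%
# Note: checkpoints restore weights but not eval mode; disabling dropout and
# batch-norm updates is a safe default before extracting features.
model = model.eval()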

# %%
# Extract the encoder and decoder feature maps
with torch.inference_mode():
    # encoder
    encoder_features = model.model.encoder(
        torch.from_numpy(phase_img.astype(np.float32)).to(model.device)
    )[0]
    encoder_features_np = [f.detach().cpu().numpy() for f in encoder_features]

    # Print the encoder feature map shapes
    for f in encoder_features_np:
        print(f.shape)

    # decoder: reverse the encoder features so the deepest map seeds the
    # decoder, with the remaining maps used as skip connections
    features = encoder_features.copy()
    features.reverse()
    feat = features[0]
    features.append(None)  # the last decoder stage takes no skip connection
    decoder_features_np = []
    for skip, stage in zip(features[1:], model.model.decoder.decoder_stages):
        feat = stage(feat, skip)
        decoder_features_np.append(feat.detach().cpu().numpy())
    for f in decoder_features_np:
        print(f.shape)
    prediction = model.model.head(feat).detach().cpu().numpy()
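
# The head output should carry the two virtual-staining channels at the crop
# resolution; print the shape to confirm before plotting.
print("prediction:", prediction.shape)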

# Defining the colors for plotting
class Color(NamedTuple):
    r: float
    g: float
    b: float

BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902)
BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r)
GREEN = Color(0.0, 1.0, 0.0)
MAGENTA = Color(1.0, 0.0, 1.0)

# Defining the functions to rescale an image and composite the nuclear and membrane channels
def rescale_clip(image: np.ndarray) -> np.ndarray:
    # Rescale to [0, 1] and repeat into 3 channels for RGB compositing
    return rescale_intensity(image, out_range=(0, 1))[..., None].repeat(3, axis=-1)

def composite_nuc_mem(
    image: np.ndarray, nuc_color: Color, mem_color: Color
) -> np.ndarray:
    # Tint each channel with its color and sum into a single RGB composite
    c_nuc = rescale_clip(image[0]) * nuc_color
    c_mem = rescale_clip(image[1]) * mem_color
    return c_nuc + c_mem

def clip_p(image: np.ndarray) -> np.ndarray:
    # Clip to the 1st/99th percentiles, then rescale for display
    return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))
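
# Minimal check of the compositing helpers on synthetic data (hypothetical
# values): two random channels (nuclei, membrane) should composite into an
# (H, W, 3) RGB image.
dummy_pair = np.random.rand(2, 32, 32)
dummy_comp = composite_nuc_mem(dummy_pair, GREEN, MAGENTA)
print(dummy_comp.shape)  # (32, 32, 3)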

# Plot the PCA to RGB of the feature maps

f, ax = plt.subplots(10, 1, figsize=(5, 25))
n_components = 4
ax[0].imshow(phase_img[0, 0, 0], cmap="gray")
ax[0].set_title(f"Phase {phase_img.shape[1:]}")
ax[-1].imshow(clip_p(composite_nuc_mem(fluo, GREEN, MAGENTA)))
ax[-1].set_title("Fluorescence")

for level, feat in enumerate(encoder_features_np):
    ax[level + 1].imshow(pcs_to_rgb(feat, n_components=n_components))
    ax[level + 1].set_title(f"Encoder stage {level + 1} {feat.shape[1:]}")

for level, feat in enumerate(decoder_features_np):
    ax[5 + level].imshow(pcs_to_rgb(feat, n_components=n_components))
    ax[5 + level].set_title(f"Decoder stage {level + 1} {feat.shape[1:]}")

pred_comp = composite_nuc_mem(prediction[0, :, 0], BOP_BLUE, BOP_ORANGE)
ax[-2].imshow(clip_p(pred_comp))
ax[-2].set_title(f"Prediction {prediction.shape[1:]}")

for a in ax.ravel():
    a.axis("off")
plt.tight_layout()
plt.show()

# %% [markdown] tags=[]
# <div class="alert alert-success">
#
# <h2>
# 🎉 The end of the notebook 🎉
# Continue to Part 2: Image translation with generative models.
# </h2>
#
# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned.
# </div>
