Skip to content

Commit

Permalink
making this readable
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Aug 24, 2024
1 parent 7fa774d commit c497234
Showing 1 changed file with 64 additions and 38 deletions.
102 changes: 64 additions & 38 deletions part_1/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,14 +1471,27 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
plt.show()


# %% [markdown] tags=[]
# <div class="alert alert-success">

# <h2> Checkpoint 2 </h2>
#
# Congratulations! You have completed the second checkpoint. You have:
# - Visualized the predictions and segmentations of the model. <br>
# - Evaluated the performance of the model using pixel-based metrics and segmentation-based metrics. <br>
# - Compared the performance of the model you trained with the pretrained model. <br>
#
# </div>

# %%[markdown] tags=[]
# <div class="alert alert-info">
#
# <h3> Task 3.1: Let's look at what the model is learning </h3>
#
# - Here we will visualize the encoder feature maps of the trained model. <br>
# - We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap <br>
# - If you are unfamiliar with Principal Component Analysis (PCA), you can read up <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">here</a> <br>
# - Run the next cells. We will visualize the encoder feature maps of the trained model.
# We will use PCA to visualize the feature maps by mapping the first 3 principal components to a colormap `Color` <br>
#
#
# </div>

#%%
Expand All @@ -1493,6 +1506,7 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from typing import NamedTuple
from monai.networks.layers import GaussianFilter

def feature_map_pca(feature_map: np.array, n_components: int = 8) -> PCA:
"""
Expand Down Expand Up @@ -1553,15 +1567,29 @@ def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:
plt.imshow(phase_img[0,0,0], cmap="gray")

# %%
# TODO: modify the path to your model checkpoint
# phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
# ))[-1]
#%% [markdown] tags=[]
# <div class="alert alert-info">
#
# - For the following tasks we will use the pretrained model to extract the encoder and decoder features <br>
# - Extra: If you are done with the whole checkpoint, you can try to look at what your trained model learned.
# </div>
#%%

# TODO: rerun with pretrained
# Loading the pretrained model
pretrained_model_ckpt = (
top_dir / "06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
)
# model config as before
phase2fluor_config = dict(
in_channels=1,
out_channels=2,
encoder_blocks=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
decoder_conv_blocks=2,
stem_kernel_size=(1, 2, 2),
in_stack_depth=1,
pretraining=False,
)

# load model
model = VSUNet.load_from_checkpoint(
Expand All @@ -1571,8 +1599,8 @@ def pcs_to_rgb(feat: np.ndarray, n_components: int = 8) -> np.ndarray:
accelerator="gpu",
)

# %%
# extract features
# %% tags=[]
# Extract features
with torch.inference_mode():
# encoder
encoder_features = model.model.encoder(torch.from_numpy(phase_img.astype(np.float32)).to(model.device))[0]
Expand Down Expand Up @@ -1600,7 +1628,7 @@ class Color(NamedTuple):
r: float
g: float
b: float

# Defining the colors for plottting the PCA
BOP_ORANGE = Color(0.972549, 0.6784314, 0.1254902)
BOP_BLUE = Color(BOP_ORANGE.b, BOP_ORANGE.g, BOP_ORANGE.r)
GREEN = Color(0.0, 1.0, 0.0)
Expand All @@ -1620,8 +1648,10 @@ def composite_nuc_mem(
def clip_p(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))

# Plot the PCA to RGB of the feature maps
def clip_highlight(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(0, np.percentile(image, 99.5)))

# Plot the PCA to RGB of the feature maps
f, ax = plt.subplots(10, 1, figsize=(5, 25))
n_components = 4
ax[0].imshow(phase_img[0, 0, 0], cmap="gray")
Expand All @@ -1645,24 +1675,31 @@ def clip_p(image: np.ndarray) -> np.ndarray:
a.axis("off")
plt.tight_layout()

#%% [markdown] tags=[]
# %% [markdown] tags=["task"]
# <div class="alert alert-info">
#
# ### Task 3.2: Select a sample batch to test the range of validty of the model
# - Run the next cell to setup the your dataloader for `test` <br>
# - Select a test batch from the `test_dataloader` by changing the `batch_number` <br>
# - Examine the plot of the source and target images of the batch <br>
#
# ### Range of validity
# -
#
# <b> Note the 2D images have different focus </b> <br>
# </div>
#%%

# %%
from monai.transforms import GaussianSmooth
from monai.networks.layers import GaussianFilter
def clip_highlight(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(0, np.percentile(image, 99.5)))
YX_PATCH_SIZE = (256*2,256*2)
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]

normalizations = [
NormalizeSampled(
keys=source_channel + target_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
)
]

# %%
YX_PATCH_SIZE = (256*2,256*2)
# Re-load the dataloader
phase2fluor_2D_data = HCSDataModule(
data_path,
Expand All @@ -1676,25 +1713,15 @@ def clip_highlight(image: np.ndarray) -> np.ndarray:
yx_patch_size=YX_PATCH_SIZE,
augmentations=[],
normalizations=normalizations,
# ground_truth_masks=''
)
phase2fluor_2D_data.setup("test")


# %% [markdown] tags=["task"]
# <div class="alert alert-info">
#
# ### Task 3.2: Select a sample batch to test the range of validty of the model
#
# - Select a test batch from the `test_dataloader` by changing the `batch_number` <br>
# - Examine the plot of the source and target images of the batch <br>
#
# <b> Note the 2D images have different focus </b> <br>
# </div>
#%% tags=["task"]
# ########## TODO ##############
batch_number= 3
# #######################
y_slice = slice(Y//2-256*n//2, Y//2+256*n//2)
x_slice = slice(X//2-256*n//2, X//2+256*n//2)

# Iterate through the test dataloader to get the desired batch
i = 0
for batch in phase2fluor_2D_data.test_dataloader():
Expand Down Expand Up @@ -1752,7 +1779,6 @@ def clip_highlight(image: np.ndarray) -> np.ndarray:
ax[1, 1].imshow(pred_composite[0])
ax[1,0].set_title('No perturbation')


# Select a sigma for the Gaussian filtering
# ########## TODO ##############
# Tensor dimensions (B,C,D,H,W).
Expand Down Expand Up @@ -1810,6 +1836,7 @@ def clip_highlight(image: np.ndarray) -> np.ndarray:
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[2, 1].imshow(pred_composite[0])

# %% [markdown] tags=[]
# <div class="alert alert-info">
#
Expand Down Expand Up @@ -1862,7 +1889,6 @@ def clip_highlight(image: np.ndarray) -> np.ndarray:


#%% tags=["solution"]

# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ##############
# This plots all perturbations

Expand Down

0 comments on commit c497234

Please sign in to comment.