Skip to content

Commit

Permalink
adding the guassian smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Aug 24, 2024
1 parent 71840d6 commit 3d81a01
Showing 1 changed file with 294 additions and 2 deletions.
296 changes: 294 additions & 2 deletions part_1/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,7 +1615,7 @@ def composite_nuc_mem(
) -> np.ndarray:
c_nuc = rescale_clip(image[0]) * nuc_color
c_mem = rescale_clip(image[1]) * mem_color
return c_nuc + c_mem
return rescale_intensity(c_nuc + c_mem, out_range=(0, 1))

def clip_p(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(*np.percentile(image, [1, 99])))
Expand Down Expand Up @@ -1645,6 +1645,297 @@ def clip_p(image: np.ndarray) -> np.ndarray:
a.axis("off")
plt.tight_layout()

#%% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Range of validity
# -
#
# </div>
#%%

# %%
from monai.transforms import GaussianSmooth
from monai.networks.layers import GaussianFilter
def clip_highlight(image: np.ndarray) -> np.ndarray:
return rescale_intensity(image.clip(0, np.percentile(image, 99.5)))


# %%
YX_PATCH_SIZE = (256*2,256*2)
# Re-load the dataloader
phase2fluor_2D_data = HCSDataModule(
data_path,
architecture="UNeXt2_2D",
source_channel=source_channel,
target_channel=target_channel,
z_window_size=1,
split_ratio=0.8,
batch_size=1,
num_workers=8,
yx_patch_size=YX_PATCH_SIZE,
augmentations=[],
normalizations=normalizations,
# ground_truth_masks=''
)
phase2fluor_2D_data.setup("test")


# %% [markdown] tags=["task"]
# <div class="alert alert-info">
#
# ### Task 3.2: Select a sample batch to test the range of validty of the model
#
# - Select a test batch from the `test_dataloader` by changing the `batch_number` <br>
# - Examine the plot of the source and target images of the batch <br>
#
# <b> Note the 2D images have different focus </b> <br>
# </div>
#%% tags=["task"]
# ########## TODO ##############
batch_number= 3
# #######################
# Iterate through the test dataloader to get the desired batch
i = 0
for batch in phase2fluor_2D_data.test_dataloader():
# break if we reach the desired batch
if i == batch_number-1:
break
i += 1

# Plot the batch source and target images
f, ax = plt.subplots(1, 2, figsize=(8, 12))
target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
ax[0].imshow(
batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
)
ax[1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
for a in ax.ravel():
a.axis("off")
f.tight_layout()
plt.show()

# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Task 3.3: Using the selected batch to test the model's range of validity
#
# - Given the selected batch use `monai.networks.layers.GaussianFilter` to blur the images with different sigmas.
# Check the documentation <a href="https://docs.monai.io/en/stable/networks.html#gaussianfilter">here</a> <br>
# - Plot the source and predicted images comparing the source, target and added perturbations <br>
# - How is the model's predictions given the perturbations? <br>
# </div>
#%% tags=["task"]
# ########## TODO ##############
# Try out different multiples of 256 to visualize larger/smaller crops
n = 3
# ##############################
# Center cropping the image
y_slice = slice(Y//2-256*n//2, Y//2+256*n//2)
x_slice = slice(X//2-256*n//2, X//2+256*n//2)

f, ax = plt.subplots(3, 2, figsize=(8, 12))

target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
ax[0, 0].imshow(
batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
)
ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
ax[0,0].set_title('Source and target')

# no perturbation
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[1, 1].imshow(pred_composite[0])
ax[1,0].set_title('No perturbation')


# Select a sigma for the Gaussian filtering
# ########## TODO ##############
# Tensor dimensions (B,C,D,H,W).
# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigmas
# Hint: Spatial (D,H,W)
gaussian_blur = GaussianFilter(....)
# #############################
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = gaussian_blur(phase)
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[2, 1].imshow(pred_composite[0])

#%% tags=["solution"]
# ########## SOLUTION ##############
# Try out different multiples of 256 to visualize larger/smaller crops
n = 3
# ##############################
# Center cropping the image
y_slice = slice(Y//2-256*n//2, Y//2+256*n//2)
x_slice = slice(X//2-256*n//2, X//2+256*n//2)

f, ax = plt.subplots(3, 2, figsize=(8, 12))

target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
ax[0, 0].imshow(
batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
)
ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
ax[0,0].set_title('Source and target')

# no perturbation
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[1, 1].imshow(pred_composite[0])
ax[1,0].set_title('No perturbation')


# Select a sigma for the Gaussian filtering
# ########## SOLUTION ##############
# Tensor dimensions (B,C,D,H,W).
# Hint: Use the GaussianFilter layer to blur the phase image. Provide the num spatial dimensions and sigma
# Hint: Spatial (D,H,W). Apply the same sigma to H,W
gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2))
# #############################
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = gaussian_blur(phase)
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[2, 1].imshow(pred_composite[0])
# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Task 3.3: Using the selected batch to test the model's range of validity
#
# - Scale the pixel values up/down of the phase image <br>
# - Plot the source and predicted images comparing the source, target and added perturbations <br>
# - How is the model's predictions given the perturbations? <br>
# </div>

#%% tags=["task"]
n = 3
y_slice = slice(Y//2, Y//2+256*n)
x_slice = slice(X//2, X//2+256*n)
f, ax = plt.subplots(3, 2, figsize=(8, 12))

target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
ax[0, 0].imshow(
batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
)
ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
ax[0,0].set_title('Source and target')

# no perturbation
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[1, 1].imshow(pred_composite[0])
ax[1,0].set_title('No perturbation')


# 2-sigma gaussian blur
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
# ########## TODO ##############
# Hint: Scale the phase intensity
phase = phase * ......
# #######################
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[2, 1].imshow(pred_composite[0])
ax[2,0].set_title('Gaussian Blur Sigma=2')

#%% [markdown]
# ########## TODO ##############
# - How is the model's predictions given the blurring and scaling perturbations? <br>


#%% tags=["solution"]

# ########## SOLUTIONS FOR ALL POSSIBLE PLOTTINGS ##############
# This plots all perturbations

n = 3
y_slice = slice(Y//2, Y//2+256*n)
x_slice = slice(X//2, X//2+256*n)
f, ax = plt.subplots(6, 2, figsize=(8, 12))

target_composite = composite_nuc_mem(batch["target"][0].cpu().numpy(), GREEN, MAGENTA)
ax[0, 0].imshow(
batch["source"][0, 0, 0,y_slice,x_slice].cpu().numpy(), cmap="gray", vmin=-15, vmax=15
)
ax[0, 1].imshow(clip_highlight(target_composite[0,y_slice,x_slice]))
ax[0,0].set_title('Source and target')

# no perturbation
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[1, 0].imshow(phase[0,0,0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[1, 1].imshow(pred_composite[0])
ax[1,0].set_title('No perturbation')


# 2-sigma gaussian blur
gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,2,2))
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = gaussian_blur(phase)
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[2, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[2, 1].imshow(pred_composite[0])
ax[2,0].set_title('Gaussian Blur Sigma=2')


# 5-sigma gaussian blur
gaussian_blur = GaussianFilter(spatial_dims=3, sigma=(0,5,5))
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = gaussian_blur(phase)
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[3, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[3, 1].imshow(pred_composite[0])
ax[3,0].set_title('Gaussian Blur Sigma=5')


# 0.1x scaling
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = phase*0.1
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[4, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[4, 1].imshow(pred_composite[0])
ax[4,0].set_title('0.1x scaling')

# 10x scaling
with torch.inference_mode():
phase = batch["source"].to(model.device)[:,:,:,y_slice,x_slice]
phase = phase* 10
pred = model(phase).cpu().numpy()
pred_composite = composite_nuc_mem(pred[0], BOP_BLUE, BOP_ORANGE)
ax[5, 0].imshow(phase[0, 0, 0].cpu().numpy(), cmap="gray", vmin=-15, vmax=15)
ax[5, 1].imshow(pred_composite[0])
ax[5,0].set_title('10x scaling')

for a in ax.ravel():
a.axis("off")

f.tight_layout()
# %% [markdown] tags=[]
# <div class="alert alert-success">

Expand All @@ -1654,4 +1945,5 @@ def clip_p(image: np.ndarray) -> np.ndarray:

# Congratulations! You have trained an image translation model, evaluated its performance, and explored what the network has learned.

# </div>
# </div>
# %%

0 comments on commit 3d81a01

Please sign in to comment.