Skip to content

Commit

Permalink
order of the augmentations were off on the task
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Aug 25, 2024
1 parent 8d94687 commit 4832acd
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,22 +518,32 @@ def log_batch_jupyter(batch):
num_samples=2,
w_key=target_channel[0],
),
# #######################
# ##### TODO ########
# #######################
## TODO: Add Random Affine Transorms
## Write code below


# #######################
RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
# #######################
# ##### TODO ########
# #######################
## TODO: Add Random Gaussian Noise
## Write code below



# #######################
RandGaussianSmoothd(
keys=source_channel,
sigma_x=(0.25, 0.75),
sigma_y=(0.25, 0.75),
sigma_z=(0.0, 0.0),
prob=0.5,
),
# #######################
# ##### TODO ########
# #######################
## TODO: Add Random Affine Transorms
## Write code below
## TODO: Add Random Gaussian Noise
## Write code below
]

normalizations = [
Expand Down Expand Up @@ -754,9 +764,10 @@ def log_batch_jupyter(batch):
freeze_encoder=False,
)

# %% [markdown] tags=[]
# ### Instantiate data module and trainer, test that we are setup to launch training.
# %%
# #######################
# ##### SOLUTION ########
# #######################
# Selecting the source and target channel names from the dataset.
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
Expand All @@ -774,14 +785,16 @@ def log_batch_jupyter(batch):
augmentations=augmentations,
normalizations=normalizations,
)
# #######################
# ##### SOLUTION ########
# #######################
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID],precision='16-mixed', fast_dev_run=True)

# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)


# %% [markdown] tags=[]
# ## View model graph.
#
Expand Down

0 comments on commit 4832acd

Please sign in to comment.