Skip to content

Commit

Permalink
Migrate from wandb to tensorboard (#122)
Browse files Browse the repository at this point in the history
* wip: use lightning's tensorboard logger instead of wandb

* private logging methods

* log center slice only

* fix tensor cloning

* only log metrics on epoch

* add simple demo training script

* fix flaky test

* log graph + profiling

* switch to simple profiler

---------

Co-authored-by: Shalin Mehta <[email protected]>
  • Loading branch information
2 people authored and edyoshikun committed Aug 7, 2024
1 parent 3184b76 commit 64ea7cf
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 314 deletions.
45 changes: 45 additions & 0 deletions applications/contrastive_phenotyping/demo_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import DeviceStatsMonitor


from viscy.data.triplet import TripletDataModule
from viscy.light.engine import ContrastiveModule


def main():
dm = TripletDataModule(
data_path="/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr",
tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr",
source_channel=["Phase3D", "RFP"],
z_range=(20, 35),
batch_size=16,
num_workers=10,
initial_yx_patch_size=(384, 384),
final_yx_patch_size=(224, 224),
)
model = ContrastiveModule(
backbone="convnext_tiny",
in_channels=2,
log_batches_per_epoch=2,
log_samples_per_batch=3,
)
trainer = Trainer(
max_epochs=5,
limit_train_batches=10,
limit_val_batches=5,
logger=TensorBoardLogger(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb",
log_graph=True,
default_hp_metric=True,
),
log_every_n_steps=1,
callbacks=[ModelCheckpoint()],
profiler="simple", # other options: "advanced" uses cprofiler, "pytorch" uses pytorch profiler.
)
trainer.fit(model, dm)


if __name__ == "__main__":
main()
6 changes: 2 additions & 4 deletions tests/preprocessing/test_pixel_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
def test_sematic_class_weights(small_hcs_dataset):
weights = sematic_class_weights(small_hcs_dataset, "GFP")
assert weights.shape == (3,)
assert_allclose(weights[0], 1.0)
assert_allclose(weights[0], 1.0, atol=1e-5)
# infinity
assert weights[1] > 1.0
assert weights[2] > 1.0
assert sematic_class_weights(
small_hcs_dataset, "GFP", num_classes=2
).shape == (2,)
assert sematic_class_weights(small_hcs_dataset, "GFP", num_classes=2).shape == (2,)
8 changes: 5 additions & 3 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ def _read_norm_meta(fov: Position) -> NormMeta | None:
for channel, channel_values in norm_meta.items():
for level, level_values in channel_values.items():
for stat, value in level_values.items():
norm_meta[channel][level][stat] = torch.tensor(
value, dtype=torch.float32
)
if isinstance(value, Tensor):
value = value.clone().float()
else:
value = torch.tensor(value, dtype=torch.float32)
norm_meta[channel][level][stat] = value
return norm_meta


Expand Down
Loading

0 comments on commit 64ea7cf

Please sign in to comment.