version lightning CLI example (#128)
* version lightning CLI example

* add some documentation

* ignore slurm output

* add more tips
mattersoflight authored and edyoshikun committed Aug 15, 2024
1 parent 8a137d4 commit 544fe8d
Showing 3 changed files with 162 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,6 +7,9 @@ __pycache__/
# written by setuptools_scm
*/_version.py

# slurm output files
slurm-*
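# (sbatch writes job logs to slurm-<jobid>.out by default, which this pattern matches)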

# Distribution / packaging
.Python
build/
115 changes: 115 additions & 0 deletions applications/contrastive_phenotyping/demo_cli_fit.yml
@@ -0,0 +1,115 @@
# See help here on how to configure hyper-parameters with config files: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: ddp
  devices: 4
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations
      version: chocolate # this is the name of the experiment. The logs will be saved in save_dir/lightning_logs/version
      log_graph: True
  # Nesting the logger config like this is equivalent to supplying the following argument to lightning.pytorch.Trainer:
  # logger=TensorBoardLogger(
  #     "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
  #     log_graph=True,
  #     version="chocolate",
  # )
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: loss/val
        every_n_epochs: 1
        save_top_k: 4
        save_last: true
  fast_dev_run: false
  max_epochs: 100
  log_every_n_steps: 10
  enable_checkpointing: true
  inference_mode: true
  use_distributed_sampler: true
model:
  backbone: convnext_tiny
  in_channels: 2
  log_batches_per_epoch: 3
  log_samples_per_batch: 3
  lr: 0.0002
data:
  data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
  source_channel:
    - Phase3D
    - RFP
  z_range: [25, 40]
  batch_size: 32
  num_workers: 12
  initial_yx_patch_size: [384, 384]
  final_yx_patch_size: [192, 192]
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [Phase3D]
        level: fov_statistics
        subtrahend: mean
        divisor: std
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        keys: [RFP]
        lower: 50
        upper: 99
        b_min: 0.0
        b_max: 1.0
  augmentations:
    - class_path: viscy.transforms.RandAffined
      init_args:
        keys: [Phase3D, RFP]
        prob: 0.8
        scale_range: [0, 0.2, 0.2]
        rotate_range: [3.14, 0.0, 0.0]
        shear_range: [0.0, 0.01, 0.01]
        padding_mode: zeros
    - class_path: viscy.transforms.RandAdjustContrastd
      init_args:
        keys: [RFP]
        prob: 0.5
        gamma: [0.7, 1.3]
    - class_path: viscy.transforms.RandAdjustContrastd
      init_args:
        keys: [Phase3D]
        prob: 0.5
        gamma: [0.8, 1.2]
    - class_path: viscy.transforms.RandScaleIntensityd
      init_args:
        keys: [RFP]
        prob: 0.7
        factors: 0.5
    - class_path: viscy.transforms.RandScaleIntensityd
      init_args:
        keys: [Phase3D]
        prob: 0.5
        factors: 0.5
    - class_path: viscy.transforms.RandGaussianSmoothd
      init_args:
        keys: [Phase3D, RFP]
        prob: 0.5
        sigma_x: [0.25, 0.75]
        sigma_y: [0.25, 0.75]
        sigma_z: [0.0, 0.0]
    - class_path: viscy.transforms.RandGaussianNoised
      init_args:
        keys: [RFP]
        prob: 0.5
        mean: 0.0
        std: 0.5
    - class_path: viscy.transforms.RandGaussianNoised
      init_args:
        keys: [Phase3D]
        prob: 0.5
        mean: 0.0
        std: 0.2
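Assuming the viscy.cli.contrastive_triplet entry point follows standard Lightning CLI behavior (per the docs linked at the top of the config), individual keys can also be overridden on the command line without editing the file; dotted flags take precedence over values loaded with -c. A minimal sketch (the version name "strawberry" and the lower learning rate are illustrative only):

# Launch with the shared config, but log under a new version with a lower learning rate:
python -m viscy.cli.contrastive_triplet fit \
    -c demo_cli_fit.yml \
    --trainer.logger.init_args.version=strawberry \
    --model.lr=0.0001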
44 changes: 44 additions & 0 deletions applications/contrastive_phenotyping/demo_cli_fit_slurm.sh
@@ -0,0 +1,44 @@
#!/bin/bash

#SBATCH --job-name=contrastive_origin
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --partition=gpu
#SBATCH --cpus-per-task=14
#SBATCH --mem-per-cpu=15G
#SBATCH --time=0-20:00:00
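# Note: --ntasks-per-node and --gres=gpu match trainer.devices in demo_cli_fit.yml
# (4 GPUs, one task per GPU, as DDP under srun expects).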

# debugging flags (optional)
# https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1


# Cleanup function to remove the temporary files
function cleanup() {
    rm -rf /tmp/$SLURM_JOB_ID/*.zarr
    echo "Cleanup Completed."
}

# Trap the EXIT signal sent to the process and invoke the cleanup function.
trap cleanup EXIT

# Activate the conda environment - specific to your installation!
module load anaconda/2022.05
# Replace this path with the path to your own conda environment.
conda activate /hpc/mydata/$USER/envs/viscy

config=./demo_cli_fit.yml

# Print the job info and the config to stdout so the job ID can be matched to its configuration.
scontrol show job $SLURM_JOB_ID
cat $config

# Run the training CLI
srun python -m viscy.cli.contrastive_triplet fit -c $config

# Tips:
# 1. run this script with `sbatch demo_cli_fit_slurm.sh`
# 2. check the status of the job with `squeue -u $USER`
# 3. monitor the job with `turm -u first.last` (load turm first with `module load turm`)
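# 4. inspect training curves with TensorBoard, e.g.:
#    tensorboard --logdir /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs
#    (this logdir follows from save_dir in demo_cli_fit.yml; adjust it if you change the config)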

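Because the ModelCheckpoint callback sets save_last: true, an interrupted run can be resumed from its last checkpoint via the fit subcommand's --ckpt_path argument. A minimal sketch, assuming Lightning's default checkpoint layout under the logger's version directory (verify the actual path for your run):

# Resume the "chocolate" run from its last checkpoint:
python -m viscy.cli.contrastive_triplet fit -c demo_cli_fit.yml \
    --ckpt_path /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/chocolate/checkpoints/last.ckpt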