Skip to content

Commit

Permalink
inherit base CLI class for tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Oct 10, 2024
1 parent 9d99f23 commit dfb9fa7
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 87 deletions.
71 changes: 71 additions & 0 deletions viscy/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging
import os
import sys
from datetime import datetime

import torch
from jsonargparse import lazy_instance
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.trainer import VisCyTrainer


class VisCyCLI(LightningCLI):
"""Extending lightning CLI arguments and defualts."""

@staticmethod
def subcommands() -> dict[str, set[str]]:
subcommands = LightningCLI.subcommands()
subcommands["preprocess"] = {"model", "dataloaders", "datamodule"}
subcommands["export"] = {"model", "dataloaders", "datamodule"}
return subcommands

def add_arguments_to_parser(self, parser):
parser.set_defaults(
{
"trainer.logger": lazy_instance(
TensorBoardLogger,
save_dir="",
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"),
log_graph=True,
)
}
)


def run_cli(
cli_class: type[LightningCLI],
model_class: type[LightningModule],
datamodule_class: type[LightningDataModule],
trainer_class: type[VisCyTrainer],
):
"""
Main Lightning CLI entry point.
Parameters
----------
cli_class : type[LightningCLI]
Lightning CLI class
model_class : type[LightningModule]
Lightning module class
datamodule_class : type[LightningDataModule]
Lightning datamodule class
trainer_class : type[VisCyTrainer]
Lightning trainer class
"""
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO)
logging.getLogger("lightning.pytorch").setLevel(log_level)
torch.set_float32_matmul_precision("high")
seed = True
if "preprocess" in sys.argv:
seed = False
model_class = LightningModule
datamodule_class = None
_ = cli_class(
model_class=model_class,
datamodule_class=datamodule_class,
trainer_class=trainer_class,
seed_everything_default=seed,
)
48 changes: 8 additions & 40 deletions viscy/representation/__main__.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,12 @@
import logging
import os
from datetime import datetime

import torch
from jsonargparse import lazy_instance
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.data.triplet import TripletDataModule
from viscy.representation.engine import ContrastiveModule
from viscy.cli import VisCyCLI, run_cli
from viscy.data.hcs import HCSDataModule
from viscy.trainer import VisCyTrainer
from viscy.translation.engine import VSUNet


class ContrastiveLightningCLI(LightningCLI):
"""Lightning CLI with default logger."""

def add_arguments_to_parser(self, parser):
parser.set_defaults(
{
"trainer.logger": lazy_instance(
TensorBoardLogger,
save_dir="",
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"),
log_graph=True,
)
}
)


def main():
"""Main Lightning CLI entry point."""
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO)
logging.getLogger("lightning.pytorch").setLevel(log_level)
torch.set_float32_matmul_precision("high")
_ = ContrastiveLightningCLI(
model_class=ContrastiveModule,
datamodule_class=TripletDataModule,
if __name__ == "__main__":
run_cli(
cli_class=VisCyCLI,
model_class=VSUNet,
datamodule_class=HCSDataModule,
trainer_class=VisCyTrainer,
)


if __name__ == "__main__":
main()
55 changes: 8 additions & 47 deletions viscy/translation/__main__.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,25 @@
import logging
import os
import sys
from datetime import datetime

import torch
from jsonargparse import lazy_instance
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger

from viscy.cli import VisCyCLI, run_cli
from viscy.data.hcs import HCSDataModule
from viscy.trainer import VisCyTrainer
from viscy.translation.engine import VSUNet


class VSLightningCLI(LightningCLI):
class TranslationCLI(VisCyCLI):
"""Extending lightning CLI arguments and defualts."""

@staticmethod
def subcommands() -> dict[str, set[str]]:
subcommands = LightningCLI.subcommands()
subcommands["preprocess"] = {"model", "dataloaders", "datamodule"}
subcommands["export"] = {"model", "dataloaders", "datamodule"}
return subcommands

def add_arguments_to_parser(self, parser):
super().add_arguments_to_parser(parser)
if "preprocess" not in sys.argv:
parser.link_arguments("data.yx_patch_size", "model.example_input_yx_shape")
parser.link_arguments("model.architecture", "data.architecture")
parser.set_defaults(
{
"trainer.logger": lazy_instance(
TensorBoardLogger,
save_dir="",
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"),
log_graph=True,
)
}
)


def main():
"""Main Lightning CLI entry point."""
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO)
logging.getLogger("lightning.pytorch").setLevel(log_level)
torch.set_float32_matmul_precision("high")
model_class = VSUNet
datamodule_class = HCSDataModule
seed = True
if "preprocess" in sys.argv:
seed = False
model_class = LightningModule
datamodule_class = None
_ = VSLightningCLI(
model_class=model_class,
datamodule_class=datamodule_class,
if __name__ == "__main__":
run_cli(
cli_class=TranslationCLI,
model_class=VSUNet,
datamodule_class=HCSDataModule,
trainer_class=VisCyTrainer,
seed_everything_default=seed,
)


if __name__ == "__main__":
main()

0 comments on commit dfb9fa7

Please sign in to comment.