-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
87 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import logging | ||
import os | ||
import sys | ||
from datetime import datetime | ||
|
||
import torch | ||
from jsonargparse import lazy_instance | ||
from lightning.pytorch import LightningDataModule, LightningModule | ||
from lightning.pytorch.cli import LightningCLI | ||
from lightning.pytorch.loggers import TensorBoardLogger | ||
|
||
from viscy.trainer import VisCyTrainer | ||
|
||
|
||
class VisCyCLI(LightningCLI): | ||
"""Extending lightning CLI arguments and defualts.""" | ||
|
||
@staticmethod | ||
def subcommands() -> dict[str, set[str]]: | ||
subcommands = LightningCLI.subcommands() | ||
subcommands["preprocess"] = {"model", "dataloaders", "datamodule"} | ||
subcommands["export"] = {"model", "dataloaders", "datamodule"} | ||
return subcommands | ||
|
||
def add_arguments_to_parser(self, parser): | ||
parser.set_defaults( | ||
{ | ||
"trainer.logger": lazy_instance( | ||
TensorBoardLogger, | ||
save_dir="", | ||
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), | ||
log_graph=True, | ||
) | ||
} | ||
) | ||
|
||
|
||
def run_cli( | ||
cli_class: type[LightningCLI], | ||
model_class: type[LightningModule], | ||
datamodule_class: type[LightningDataModule], | ||
trainer_class: type[VisCyTrainer], | ||
): | ||
""" | ||
Main Lightning CLI entry point. | ||
Parameters | ||
---------- | ||
cli_class : type[LightningCLI] | ||
Lightning CLI class | ||
model_class : type[LightningModule] | ||
Lightning module class | ||
datamodule_class : type[LightningDataModule] | ||
Lightning datamodule class | ||
trainer_class : type[VisCyTrainer] | ||
Lightning trainer class | ||
""" | ||
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) | ||
logging.getLogger("lightning.pytorch").setLevel(log_level) | ||
torch.set_float32_matmul_precision("high") | ||
seed = True | ||
if "preprocess" in sys.argv: | ||
seed = False | ||
model_class = LightningModule | ||
datamodule_class = None | ||
_ = cli_class( | ||
model_class=model_class, | ||
datamodule_class=datamodule_class, | ||
trainer_class=trainer_class, | ||
seed_everything_default=seed, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,12 @@ | ||
import logging | ||
import os | ||
from datetime import datetime | ||
|
||
import torch | ||
from jsonargparse import lazy_instance | ||
from lightning.pytorch.cli import LightningCLI | ||
from lightning.pytorch.loggers import TensorBoardLogger | ||
|
||
from viscy.data.triplet import TripletDataModule | ||
from viscy.representation.engine import ContrastiveModule | ||
from viscy.cli import VisCyCLI, run_cli | ||
from viscy.data.hcs import HCSDataModule | ||
from viscy.trainer import VisCyTrainer | ||
from viscy.translation.engine import VSUNet | ||
|
||
|
||
class ContrastiveLightningCLI(LightningCLI): | ||
"""Lightning CLI with default logger.""" | ||
|
||
def add_arguments_to_parser(self, parser): | ||
parser.set_defaults( | ||
{ | ||
"trainer.logger": lazy_instance( | ||
TensorBoardLogger, | ||
save_dir="", | ||
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), | ||
log_graph=True, | ||
) | ||
} | ||
) | ||
|
||
|
||
def main(): | ||
"""Main Lightning CLI entry point.""" | ||
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) | ||
logging.getLogger("lightning.pytorch").setLevel(log_level) | ||
torch.set_float32_matmul_precision("high") | ||
_ = ContrastiveLightningCLI( | ||
model_class=ContrastiveModule, | ||
datamodule_class=TripletDataModule, | ||
if __name__ == "__main__": | ||
run_cli( | ||
cli_class=VisCyCLI, | ||
model_class=VSUNet, | ||
datamodule_class=HCSDataModule, | ||
trainer_class=VisCyTrainer, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,25 @@ | ||
import logging | ||
import os | ||
import sys | ||
from datetime import datetime | ||
|
||
import torch | ||
from jsonargparse import lazy_instance | ||
from lightning.pytorch import LightningModule | ||
from lightning.pytorch.cli import LightningCLI | ||
from lightning.pytorch.loggers import TensorBoardLogger | ||
|
||
from viscy.cli import VisCyCLI, run_cli | ||
from viscy.data.hcs import HCSDataModule | ||
from viscy.trainer import VisCyTrainer | ||
from viscy.translation.engine import VSUNet | ||
|
||
|
||
class VSLightningCLI(LightningCLI): | ||
class TranslationCLI(VisCyCLI): | ||
"""Extending lightning CLI arguments and defualts.""" | ||
|
||
@staticmethod | ||
def subcommands() -> dict[str, set[str]]: | ||
subcommands = LightningCLI.subcommands() | ||
subcommands["preprocess"] = {"model", "dataloaders", "datamodule"} | ||
subcommands["export"] = {"model", "dataloaders", "datamodule"} | ||
return subcommands | ||
|
||
def add_arguments_to_parser(self, parser): | ||
super().add_arguments_to_parser(parser) | ||
if "preprocess" not in sys.argv: | ||
parser.link_arguments("data.yx_patch_size", "model.example_input_yx_shape") | ||
parser.link_arguments("model.architecture", "data.architecture") | ||
parser.set_defaults( | ||
{ | ||
"trainer.logger": lazy_instance( | ||
TensorBoardLogger, | ||
save_dir="", | ||
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), | ||
log_graph=True, | ||
) | ||
} | ||
) | ||
|
||
|
||
def main(): | ||
"""Main Lightning CLI entry point.""" | ||
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) | ||
logging.getLogger("lightning.pytorch").setLevel(log_level) | ||
torch.set_float32_matmul_precision("high") | ||
model_class = VSUNet | ||
datamodule_class = HCSDataModule | ||
seed = True | ||
if "preprocess" in sys.argv: | ||
seed = False | ||
model_class = LightningModule | ||
datamodule_class = None | ||
_ = VSLightningCLI( | ||
model_class=model_class, | ||
datamodule_class=datamodule_class, | ||
if __name__ == "__main__": | ||
run_cli( | ||
cli_class=TranslationCLI, | ||
model_class=VSUNet, | ||
datamodule_class=HCSDataModule, | ||
trainer_class=VisCyTrainer, | ||
seed_everything_default=seed, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |