diff --git a/pyproject.toml b/pyproject.toml
index b275e24..c0eb60a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dynamic = ["version"]
 requires-python = ">=3.9"
 dependencies = [
     "astropy", # Used to load fits files of sources to query HSC cutout server
+    "pytorch-ignite", # Used for distributed training, logging, etc.
     "toml", # Used to load configuration files as dictionaries
     "torch", # Used for CNN model and in train.py
     "torchvision", # Used in hsc data loader, example autoencoder, and CNN model data set
diff --git a/src/fibad/models/example_autoencoder.py b/src/fibad/models/example_autoencoder.py
index ae4ae05..3905e87 100644
--- a/src/fibad/models/example_autoencoder.py
+++ b/src/fibad/models/example_autoencoder.py
@@ -3,7 +3,8 @@
 # This example model is taken from the autoenocoder tutorial here
 # https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html
 
-import numpy as np
+# The train function has been converted into train_step for use with pytorch-ignite.
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa N812
@@ -37,6 +38,9 @@ def __init__(self, model_config, shape=(5, 250, 250)):
         self._init_encoder()
         self._init_decoder()
 
+        # Create this here for use in `train_step`, to avoid recreating it at each step.
+        self.optimizer = self._optimizer()
+
     def conv2d_multi_layer(self, input_size, num_applications, **kwargs) -> int:
         for _ in range(num_applications):
             input_size = self.conv2d_output_size(input_size, **kwargs)
@@ -103,37 +107,34 @@ def forward(self, x):
         x_hat = self._eval_decoder(z)
         return x_hat
 
-    def train(self, trainloader, device=None):
-        self.optimizer = self._optimizer()
+    def train_step(self, batch):
+        """This function contains the logic for a single training step, i.e. the
+        contents of the inner loop of an ML training process.
 
-        torch.set_grad_enabled(True)
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the inputs and labels for the current batch.
 
-        # print(f"len(trainloder) = {len(trainloader)}")
-        for epoch in range(self.config.get("epochs", 2)):
-            running_loss = 0.0
-            for batch_num, data in enumerate(trainloader, 0):
-                # When we run on a supervised dataset like CIFAR10, drop the labels given by the data loader
-                x = data[0] if isinstance(data, tuple) else data
+        Returns
+        -------
+        Current loss value
+            The loss value for the current batch.
+        """
+        # When we run on a supervised dataset like CIFAR10, drop the labels given by the data loader
+        x = batch[0] if isinstance(batch, tuple) else batch
 
-                x = x.to(device)
-                x_hat = self.forward(x)
-                loss = F.mse_loss(x, x_hat, reduction="none")
-                loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
+        x_hat = self.forward(x)
+        loss = F.mse_loss(x, x_hat, reduction="none")
+        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
 
-                self.optimizer.zero_grad()
+        self.optimizer.zero_grad()
 
-                loss.backward()
+        loss.backward()
 
-                self.optimizer.step()
-                running_loss += loss.item()
+        self.optimizer.step()
 
-                # Log every 2000 batches in an epoch, or the end of the epoch
-                # Ensure we get one log message at the end of every epoch even if the
-                # data size is less than 2000.
-                log_freq = np.min([2000, len(trainloader)])
-                if batch_num % log_freq == log_freq - 1:
-                    print(f"[{epoch + 1}, {batch_num + 1}] loss: {running_loss / 2000}")
-                    running_loss = 0.0
+        return {"loss": loss.item()}
 
     def _optimizer(self):
         return optim.Adam(self.parameters(), lr=1e-3)
diff --git a/src/fibad/models/example_cnn_classifier.py b/src/fibad/models/example_cnn_classifier.py
index 16a5f15..0fd4a85 100644
--- a/src/fibad/models/example_cnn_classifier.py
+++ b/src/fibad/models/example_cnn_classifier.py
@@ -9,7 +9,6 @@
 import torch.nn.functional as F  # noqa N812
 import torch.optim as optim
 
-# extra long import here to address a circular import issue
 from .model_registry import fibad_model
 
 logger = logging.getLogger(__name__)
@@ -17,7 +16,7 @@
 
 @fibad_model
 class ExampleCNN(nn.Module):
-    def __init__(self, model_config):
+    def __init__(self, model_config, shape):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
         self.pool = nn.MaxPool2d(2, 2)
@@ -28,6 +27,11 @@ def __init__(self, model_config):
 
         self.config = model_config
 
+        # Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`,
+        # but we define them as methods as a way to allow for more flexibility in the future.
+        self.optimizer = self._optimizer()
+        self.criterion = self._criterion()
+
     def forward(self, x):
         x = self.pool(F.relu(self.conv1(x)))
         x = self.pool(F.relu(self.conv2(x)))
@@ -37,25 +41,28 @@ def forward(self, x):
         x = self.fc3(x)
         return x
 
-    def train(self, trainloader, device=None):
-        self.optimizer = self._optimizer()
-        self.criterion = self._criterion()
+    def train_step(self, batch):
+        """This function contains the logic for a single training step, i.e. the
+        contents of the inner loop of an ML training process.
+
+        Parameters
+        ----------
+        batch : tuple
+            A tuple containing the inputs and labels for the current batch.
+
+        Returns
+        -------
+        Current loss value
+            The loss value for the current batch.
+        """
+        inputs, labels = batch
 
-        for epoch in range(self.config.get("epochs", 2)):
-            running_loss = 0.0
-            for i, data in enumerate(trainloader, 0):
-                inputs, labels = data
-                inputs, labels = inputs.to(device), labels.to(device)
-
-                self.optimizer.zero_grad()
-                outputs = self(inputs)
-                loss = self.criterion(outputs, labels)
-                loss.backward()
-                self.optimizer.step()
-                running_loss += loss.item()
-                if i % 2000 == 1999:
-                    logger.info(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}")
-                    running_loss = 0.0
+        self.optimizer.zero_grad()
+        outputs = self(inputs)
+        loss = self.criterion(outputs, labels)
+        loss.backward()
+        self.optimizer.step()
+        return {"loss": loss.item()}
 
     def _criterion(self):
         return nn.CrossEntropyLoss()
diff --git a/src/fibad/train.py b/src/fibad/train.py
index aa727fe..68b22b4 100644
--- a/src/fibad/train.py
+++ b/src/fibad/train.py
@@ -1,6 +1,8 @@
 import logging
 
 import torch
+from ignite import distributed as idist
+from ignite.engine import Engine, Events
 
 from fibad.data_loaders.data_loader_registry import fetch_data_loader_class
 from fibad.models.model_registry import fetch_model_class
@@ -9,7 +11,7 @@
 
 
 def run(config):
-    """Placeholder for training code.
+    """Run the training process for a given model and data loader.
 
     Parameters
     ----------
@@ -18,35 +20,104 @@
         dict
     """
 
+    # Fetch data loader class specified in config and create an instance of it
     data_loader_cls = fetch_data_loader_class(config)
-    fibad_data_loader = data_loader_cls(config.get("data_loader", {}))
-    data_loader = fibad_data_loader.get_data_loader()
+    data_loader = data_loader_cls(config.get("data_loader", {}))
+    # Get the pytorch.dataset from dataloader, and use it to create a distributed dataloader
+    data_set = data_loader.data_set()
+    dist_data_loader = _train_data_loader(data_set, config.get("data_loader", {}))
+
+    # Fetch model class specified in config and create an instance of it
     model_cls = fetch_model_class(config)
-    model = model_cls(model_config=config.get("model", {}), shape=fibad_data_loader.shape())
+    model = model_cls(model_config=config.get("model", {}), shape=data_loader.shape())
+
+    # Create trainer, a pytorch-ignite `Engine` object
+    trainer = _create_trainer(model)
 
-    cuda_available = torch.cuda.is_available()
-    mps_available = torch.backends.mps.is_available()
+    # Run the training process
+    trainer.run(dist_data_loader, max_epochs=config.get("model", {}).get("epochs", 2))
 
-    # We don't expect mps (Apple's Metal backend) and cuda (Nvidia's backend) to ever be
-    # both available on the same system.
-    device_str = "cuda:0" if cuda_available else "cpu"
-    device_str = "mps" if mps_available else "cpu"
+    # Save the trained model
+    model.save()
+    logger.info("Finished Training")
 
-    logger.info(f"Initializing torch with device string {device_str}")
-    device = torch.device(device_str)
-    if torch.cuda.device_count() > 1:
-        # ~ PyTorch docs indicate that batch size should be < number of GPUs.
+
+def _train_data_loader(data_set, config):
+    # ~ idist.auto_dataloader will accept a **kwargs parameter, and pass values
+    # ~ through to the underlying pytorch DataLoader.
+    # ~ Currently, our config includes unexpected keys like `name`, that cause
+    # ~ an exception. It would be nice to reduce this to:
+    # ~ `data_loader = idist.auto_dataloader(data_set, **config)`
+    data_loader = idist.auto_dataloader(
+        data_set,
+        batch_size=config.get("batch_size", 4),
+        shuffle=config.get("shuffle", True),
+        num_workers=config.get("num_workers", 2),
+    )
 
-        # ~ PyTorch documentation recommends using torch.nn.parallel.DistributedDataParallel
-        # ~ instead of torch.nn.DataParallel for multi-GPU training.
-        # ~ See: https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead
-        model = torch.nn.DataParallel(model)
+    return data_loader
 
-    model.to(device)
-    model.train(data_loader, device=device)
+
+def _create_trainer(model):
+    """This function is originally copied from here:
+    https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164
 
-    model.save()
-    logger.info("Finished Training")
+    It was substantially trimmed down to make it easier to understand.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model to train.
+
+    Returns
+    -------
+    pytorch-ignite.Engine
+        Engine object that will be used to train the model.
+    """
+    # Get currently available device for training, and set the model to use it
+    device = idist.device()
+    # logger.info(f"Training on device: {device}")
+    model = idist.auto_model(model)
+
+    # Extract `train_step` from model, which can be wrapped after idist.auto_model(...)
+    if type(model) == torch.nn.parallel.DistributedDataParallel:
+        inner_train_step = model.module.train_step
+    elif type(model) == torch.nn.parallel.DataParallel:
+        inner_train_step = model.module.train_step
+    else:
+        inner_train_step = model.train_step
+
+    # Wrap the `train_step` so that batch data is moved to the appropriate device
+    def train_step(engine, batch):
+        #! This feels brittle, it would be worth revisiting this.
+        # We assume that the batch data will generally have two forms.
+        # 1) A torch.Tensor that represents N samples.
+        # 2) A tuple (or list) of torch.Tensors, where the first tensor is the
+        #    data, and the second is labels.
+        batch = batch.to(device) if isinstance(batch, torch.Tensor) else tuple(i.to(device) for i in batch)
+
+        return inner_train_step(batch)
+
+    # Create the ignite `Engine` object
+    trainer = Engine(train_step)
+
+    @trainer.on(Events.STARTED)
+    def log_training_start(trainer):
+        logger.info(f"Training model on device: {device}")
+        logger.info(f"Total epochs: {trainer.state.max_epochs}")
+
+    @trainer.on(Events.EPOCH_STARTED)
+    def log_epoch_start(trainer):
+        logger.debug(f"Starting epoch {trainer.state.epoch}")
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def log_training_loss(trainer):
+        logger.info(f"Epoch {trainer.state.epoch} run time: {trainer.state.times['EPOCH_COMPLETED']:.2f}[s]")
+        logger.info(f"Epoch {trainer.state.epoch} metrics: {trainer.state.output}")
+
+    @trainer.on(Events.COMPLETED)
+    def log_total_time(trainer):
+        logger.info(f"Total training time: {trainer.state.times['COMPLETED']:.2f}[s]")
+
+    return trainer
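
For reviewers who want to see the new pieces working together outside of fibad, below is a minimal, self-contained sketch of the train_step / Engine pattern this patch introduces. It is not part of the patch: ToyModel, the lambda process function, and the random in-memory data are hypothetical stand-ins for a fibad model, the wrapped train_step built in _create_trainer, and the output of idist.auto_dataloader.

# Minimal sketch of the train_step / Engine pattern introduced above (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F
from ignite.engine import Engine, Events


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)
        # Created once at construction time, mirroring the example models in this patch.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.layer(x)

    def train_step(self, batch):
        # Drop labels if the loader returns (inputs, labels); mirrors the example models.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}


model = ToyModel()

# The Engine's process function receives (engine, batch) and delegates to the model's train_step.
trainer = Engine(lambda engine, batch: model.train_step(batch))


@trainer.on(Events.EPOCH_COMPLETED)
def log_loss(engine):
    # engine.state.output holds the dict returned by train_step for the last batch.
    print(f"Epoch {engine.state.epoch} loss: {engine.state.output['loss']:.4f}")


# A tiny in-memory iterable stands in for the distributed dataloader.
data = [torch.randn(4, 8) for _ in range(10)]
trainer.run(data, max_epochs=2)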