diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml
index 28a12f0..ed5ff87 100644
--- a/src/fibad/fibad_default_config.toml
+++ b/src/fibad/fibad_default_config.toml
@@ -41,7 +41,7 @@
 timeout = 3600
 chunksize = 990
 
 [model]
-name = "ExampleAutoencoder"
+name = "ExampleCNN"
 # An example of requesting an external model class
 # external_class = "user_package.submodule.ExternalModel"
@@ -51,7 +51,7 @@
 epochs = 10
 
 [data_loader]
 # Name of data loader to use
-name = "HSCDataLoader"
+name = "CifarDataLoader"
 # An example of requesting an external data loader class
 # external_class = "user_package.submodule.ExternalDataLoader"
diff --git a/src/fibad/models/example_cnn_classifier.py b/src/fibad/models/example_cnn_classifier.py
index 16a5f15..07448ef 100644
--- a/src/fibad/models/example_cnn_classifier.py
+++ b/src/fibad/models/example_cnn_classifier.py
@@ -9,7 +9,6 @@
 import torch.nn.functional as F  # noqa N812
 import torch.optim as optim
 
-# extra long import here to address a circular import issue
 from .model_registry import fibad_model
 
 logger = logging.getLogger(__name__)
@@ -17,7 +16,7 @@
 
 @fibad_model
 class ExampleCNN(nn.Module):
-    def __init__(self, model_config):
+    def __init__(self, model_config, shape):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
         self.pool = nn.MaxPool2d(2, 2)
@@ -57,6 +56,19 @@ def train(self, trainloader, device=None):
                     logger.info(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}")
                     running_loss = 0.0
 
+    # Creating a train_step function to be used with pytorch-ignite
+    # ! figure out how to pass `device` correctly!!! It shouldn't be in the method signature
+    # ! I just put it there to pass linting.
+    def train_step(self, batch, device):
+        inputs, labels = batch
+        inputs, labels = inputs.to(device), labels.to(device)
+
+        self.optimizer.zero_grad()
+        outputs = self(inputs)
+        loss = self.criterion(outputs, labels)
+        loss.backward()
+        self.optimizer.step()
+
     def _criterion(self):
         return nn.CrossEntropyLoss()
 
diff --git a/src/fibad/train.py b/src/fibad/train.py
index aa727fe..818a815 100644
--- a/src/fibad/train.py
+++ b/src/fibad/train.py
@@ -1,6 +1,7 @@
 import logging
 
-import torch
+import ignite.distributed as idist
+from ignite.engine import Engine
 
 from fibad.data_loaders.data_loader_registry import fetch_data_loader_class
 from fibad.models.model_registry import fetch_model_class
@@ -20,33 +21,28 @@ def run(config):
 
     data_loader_cls = fetch_data_loader_class(config)
     fibad_data_loader = data_loader_cls(config.get("data_loader", {}))
-    data_loader = fibad_data_loader.get_data_loader()
+    data_set = fibad_data_loader.data_set()
+    data_loader = _train_data_loader(data_set, config.get("data_loader", {}))
 
     model_cls = fetch_model_class(config)
     model = model_cls(model_config=config.get("model", {}), shape=fibad_data_loader.shape())
 
-    cuda_available = torch.cuda.is_available()
-    mps_available = torch.backends.mps.is_available()
+    model = idist.auto_model(model)
 
-    # We don't expect mps (Apple's Metal backend) and cuda (Nvidia's backend) to ever be
-    # both available on the same system.
-    device_str = "cuda:0" if cuda_available else "cpu"
-    device_str = "mps" if mps_available else "cpu"
+    trainer = Engine(model.train_step)
+    trainer.run(data_loader, max_epochs=config.get("model", {}).get("epochs", 2))
 
-    logger.info(f"Initializing torch with device string {device_str}")
-
-    device = torch.device(device_str)
-    if torch.cuda.device_count() > 1:
-        # ~ PyTorch docs indicate that batch size should be < number of GPUs.
-
-        # ~ PyTorch documentation recommends using torch.nn.parallel.DistributedDataParallel
-        # ~ instead of torch.nn.DataParallel for multi-GPU training.
-        # ~ See: https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead
-        model = torch.nn.DataParallel(model)
+    logger.info("Finished Training")
 
-    model.to(device)
-    model.train(data_loader, device=device)
 
+def _train_data_loader(data_set, config):
+    # ~ idist.auto_dataloader will accept a **kwargs parameter, and pass values
+    # ~ through to the underlying pytorch DataLoader.
+    data_loader = idist.auto_dataloader(
+        data_set,
+        batch_size=config.get("batch_size", 4),
+        shuffle=config.get("shuffle", True),
+        drop_last=config.get("drop_last", False),
+    )
-
-    model.save()
-    logger.info("Finished Training")
+
+    return data_loader
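Note on the ignite wiring above: Engine invokes its process function as fn(engine, batch), so the open question flagged in the "# !" comments of train_step is how device should reach the step function. Below is a minimal, self-contained sketch, not part of this patch, of how the idist.auto_model / idist.auto_dataloader / Engine pieces are expected to fit together; it resolves the device from the enclosing scope via idist.device(). The toy nn.Linear model, random TensorDataset, and hyperparameter values are illustrative placeholders only.

# Sketch only -- not part of this patch; toy model and data stand in for the fibad classes.
import torch
import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Engine
from torch.utils.data import TensorDataset

device = idist.device()                     # resolve the device once, outside the step function
model = idist.auto_model(nn.Linear(4, 2))   # auto_model moves/wraps the model for the current backend
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 32 random samples with integer class labels, as a stand-in dataset
data_set = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
data_loader = idist.auto_dataloader(data_set, batch_size=4, shuffle=True, drop_last=False)

def train_step(engine, batch):
    # ignite calls the process function as fn(engine, batch); device comes from the
    # enclosing scope rather than from the function signature.
    inputs, labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data_loader, max_epochs=2)

In the patch itself the equivalent knobs (batch_size, shuffle, drop_last, epochs) come from the [data_loader] and [model] tables of the TOML config, with the defaults shown in _train_data_loader and run.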