WIP - Experimenting with pytorch ignite support for DistributedDataParallel.
drewoldag committed Aug 28, 2024
1 parent 79293c8 commit a681631
Showing 3 changed files with 34 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/fibad/fibad_default_config.toml
@@ -41,7 +41,7 @@ timeout = 3600
chunksize = 990

[model]
name = "ExampleAutoencoder"
name = "ExampleCNN"

# An example of requesting an external model class
# external_class = "user_package.submodule.ExternalModel"
@@ -51,7 +51,7 @@ epochs = 10

[data_loader]
# Name of data loader to use
name = "HSCDataLoader"
name = "CifarDataLoader"

# An example of requesting an external data loader class
# external_class = "user_package.submodule.ExternalDataLoader"
16 changes: 14 additions & 2 deletions src/fibad/models/example_cnn_classifier.py
@@ -9,15 +9,14 @@
import torch.nn.functional as F # noqa N812
import torch.optim as optim

-# extra long import here to address a circular import issue
from .model_registry import fibad_model

logger = logging.getLogger(__name__)


@fibad_model
class ExampleCNN(nn.Module):
-    def __init__(self, model_config):
+    def __init__(self, model_config, shape):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
@@ -57,6 +56,19 @@ def train(self, trainloader, device=None):
logger.info(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}")
running_loss = 0.0

+    # Creating a train_step function to be used with pytorch-ignite
+    # ! figure out how to pass `device` correctly!!! It shouldn't be in the method signature
+    # ! I just put it there to pass linting.
+    def train_step(self, batch, device):
+        inputs, labels = batch
+        inputs, labels = inputs.to(device), labels.to(device)

+        self.optimizer.zero_grad()
+        outputs = self(inputs)
+        loss = self.criterion(outputs, labels)
+        loss.backward()
+        self.optimizer.step()

def _criterion(self):
return nn.CrossEntropyLoss()
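
A note on the open `device` question above: ignite's Engine invokes its process function as fn(engine, batch), so the bound method train_step(self, batch, device) would receive the Engine object in the batch slot. One common pattern is to close over the device (resolved with idist.device()) rather than putting it in the signature. A minimal sketch, assuming the model exposes optimizer and criterion attributes as ExampleCNN does (a DDP-wrapped model would need model.module):

import ignite.distributed as idist
from ignite.engine import Engine

def make_train_step(model):
    device = idist.device()  # the per-process device under ignite.distributed

    def train_step(engine, batch):  # ignite calls process functions as fn(engine, batch)
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        model.optimizer.zero_grad()
        outputs = model(inputs)
        loss = model.criterion(outputs, labels)
        loss.backward()
        model.optimizer.step()
        return loss.item()

    return train_step

# Usage: trainer = Engine(make_train_step(model))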

40 changes: 18 additions & 22 deletions src/fibad/train.py
@@ -1,6 +1,7 @@
import logging

-import torch
+import ignite.distributed as idist
+from ignite.engine import Engine

from fibad.data_loaders.data_loader_registry import fetch_data_loader_class
from fibad.models.model_registry import fetch_model_class
@@ -20,33 +21,28 @@ def run(config):

data_loader_cls = fetch_data_loader_class(config)
fibad_data_loader = data_loader_cls(config.get("data_loader", {}))
-    data_loader = fibad_data_loader.get_data_loader()
+    data_set = fibad_data_loader.data_set()
+    data_loader = _train_data_loader(data_set, config.get("data_loader", {}))

model_cls = fetch_model_class(config)
model = model_cls(model_config=config.get("model", {}), shape=fibad_data_loader.shape())

-    cuda_available = torch.cuda.is_available()
-    mps_available = torch.backends.mps.is_available()
+    model = idist.auto_model(model)

-    # We don't expect mps (Apple's Metal backend) and cuda (Nvidia's backend) to ever be
-    # both available on the same system.
-    device_str = "cuda:0" if cuda_available else "cpu"
-    device_str = "mps" if mps_available else "cpu"
+    trainer = Engine(model.train_step)
+    trainer.run(data_loader, max_epochs=config.get("model", {}).get("epochs", 2))

logger.info(f"Initializing torch with device string {device_str}")

device = torch.device(device_str)
if torch.cuda.device_count() > 1:
# ~ PyTorch docs indicate that batch size should be < number of GPUs.

# ~ PyTorch documentation recommends using torch.nn.parallel.DistributedDataParallel
# ~ instead of torch.nn.DataParallel for multi-GPU training.
# ~ See: https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead
model = torch.nn.DataParallel(model)
logger.info("Finished Training")

-    model.to(device)
-    model.train(data_loader, device=device)
+def _train_data_loader(data_set, config):
+    # ~ idist.auto_dataloader will accept a **kwargs parameter, and pass values
+    # ~ through to the underlying pytorch DataLoader.
+    data_loader = idist.auto_dataloader(
+        data_set,
+        batch_size=config.get("batch_size", 4),
+        shuffle=config.get("shuffle", True),
+        drop_last=config.get("drop_last", False),
+    )

-    model.save()
-    logger.info("Finished Training")
+    return data_loader
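
Worth noting: idist.auto_model and idist.auto_dataloader only apply DistributedDataParallel and a DistributedSampler when run() executes inside a distributed context, which this commit does not yet set up. A hedged sketch of how that launch could look with ignite's idist.Parallel; the backend and process count here are assumptions, not part of this commit:

import ignite.distributed as idist

from fibad.train import run

def training(local_rank, config):
    # idist.Parallel passes the local process rank as the first argument.
    run(config)

if __name__ == "__main__":
    config = {}  # placeholder; fibad builds this from its TOML files
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        parallel.run(training, config)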
