Lightning Trainer integration in GraphGym (#4689)
* add trainer

* add datamodule

* fix tests

* fix tests

* fixes

* soft import

* revert

* update

* fix tests

* reformat

* remove test split

* remove skip_train_eval

* add soft import

* update

* update

* update

* flake

Co-authored-by: rusty1s <[email protected]>
aniketmaurya and rusty1s authored May 30, 2022
1 parent 70a3760 commit 5a4f868
Showing 9 changed files with 89 additions and 147 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
- Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))
- Added PyTorch Lightning support in GraphGym ([#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531))
- Added PyTorch Lightning support in GraphGym ([#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531), [#4689](https://github.com/pyg-team/pytorch_geometric/pull/4689))
- Added support for returning embeddings in `MLP` models ([#4625](https://github.com/pyg-team/pytorch_geometric/pull/4625))
- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620), [#4702](https://github.com/pyg-team/pytorch_geometric/pull/4702))
- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
19 changes: 5 additions & 14 deletions graphgym/main.py
@@ -13,11 +13,9 @@
set_out_dir,
set_run_dir,
)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import create_logger, set_printing
from torch_geometric.graphgym.loader import GraphGymDataModule
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
@@ -41,22 +39,15 @@
seed_everything(cfg.seed)
auto_select_device()
# Set machine learning pipeline
loaders = create_loader()
loggers = create_logger()
datamodule = GraphGymDataModule()
model = create_model()
optimizer = create_optimizer(model.parameters(), cfg.optim)
scheduler = create_scheduler(optimizer, cfg.optim)
# Print model info
logging.info(model)
logging.info(cfg)
cfg.params = params_count(model)
logging.info('Num parameters: %s', cfg.params)
# Start training
if cfg.train.mode == 'standard':
train(loggers, loaders, model, optimizer, scheduler)
else:
train_dict[cfg.train.mode](loggers, loaders, model, optimizer,
scheduler)
train(model, datamodule, logger=True)

# Aggregate results from different seeds
agg_runs(cfg.out_dir, cfg.metric_best)
# When being launched in batch mode, mark a yaml as done
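
With this change, the pipeline in main.py reduces to three calls: build a datamodule, build a model, and hand both to train(). A minimal sketch of the new flow, assuming the config has already been loaded and seeded as in the surrounding file:

    from torch_geometric.graphgym.loader import GraphGymDataModule
    from torch_geometric.graphgym.model_builder import create_model
    from torch_geometric.graphgym.train import train

    datamodule = GraphGymDataModule()      # wraps the train/val/test loaders
    model = create_model()                 # returns a GraphGymModule (a LightningModule)
    train(model, datamodule, logger=True)  # builds a pl.Trainer, then runs fit and test
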
31 changes: 9 additions & 22 deletions test/graphgym/test_graphgym.py
@@ -18,16 +18,11 @@
set_run_dir,
)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import (
LoggerCallback,
create_logger,
set_printing,
)
from torch_geometric.graphgym.logger import LoggerCallback, set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.train import train
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils import (
agg_runs,
auto_select_device,
@@ -85,11 +80,8 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
cfg.metric_best = 'auto'
cfg.custom_metrics = []

loaders = create_loader()
assert len(loaders) == 3

loggers = create_logger()
assert len(loggers) == 3
datamodule = GraphGymDataModule()
assert len(datamodule.loaders) == 3

model = create_model()
assert isinstance(model, torch.nn.Module)
@@ -98,20 +90,15 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
assert isinstance(model.post_mp, GNNNodeHead)
assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

optimizer = create_optimizer(model.parameters(), cfg.optim)
assert isinstance(optimizer, torch.optim.Adam)

scheduler = create_scheduler(optimizer, cfg.optim)
assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)
optimizer, scheduler = model.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adam)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

cfg.params = params_count(model)
assert cfg.params == 23880

train(loggers, loaders, model, optimizer, scheduler)

if use_trivial_metric:
# 6 total epochs, 4 eval epochs, 3 splits (1 training split)
assert num_trivial_metric_calls == 12 if skip_train_eval else 14
train(model, datamodule, logger=True,
trainer_config={"enable_progress_bar": False})

assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

3 changes: 0 additions & 3 deletions torch_geometric/graphgym/config.py
@@ -224,9 +224,6 @@ def set_cfg(cfg):
# ----------------------------------------------------------------------- #
cfg.train = CN()

# Training (and validation) pipeline mode
cfg.train.mode = 'standard'

# Total graph mini-batch size
cfg.train.batch_size = 16

15 changes: 15 additions & 0 deletions torch_geometric/graphgym/imports.py
@@ -0,0 +1,15 @@
import warnings

import torch

try:
import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningModule
except ImportError:
# define fallbacks
pl = object
LightningModule = torch.nn.Module
Callback = object

warnings.warn("Please install 'pytorch_lightning' for using the GraphGym "
"experiment manager via 'pip install pytorch_lightning'")
10 changes: 7 additions & 3 deletions torch_geometric/graphgym/logger.py
@@ -276,6 +276,10 @@ def val_logger(self) -> Any:
def test_logger(self) -> Any:
return self._logger[2]

def close(self):
for logger in self._logger:
logger.close()

def _get_stats(
self,
epoch_start_time: int,
@@ -354,20 +358,20 @@ def on_train_epoch_end(
pl_module: 'pl.LightningModule',
):
self.train_logger.write_epoch(trainer.current_epoch)
self.train_logger.close()

def on_validation_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self.val_logger.write_epoch(trainer.current_epoch)
self.val_logger.close()

def on_test_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self.test_logger.write_epoch(trainer.current_epoch)
self.test_logger.close()

def on_fit_end(self, trainer, pl_module):
self.close()
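
The per-epoch close() calls are gone: logger file handles now stay open for the whole run and are released once by on_fit_end via the new close() method. If you drive the Trainer yourself rather than going through graphgym's train(), the callback attaches like any other; a sketch assuming pytorch_lightning is installed and a GraphGym config has been set up (LoggerCallback builds its internal loggers from cfg):

    import pytorch_lightning as pl

    from torch_geometric.graphgym.logger import LoggerCallback

    trainer = pl.Trainer(callbacks=[LoggerCallback()], max_epochs=10)
    # trainer.fit(model, datamodule=datamodule)  # loggers are closed in on_fit_end
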
11 changes: 2 additions & 9 deletions torch_geometric/graphgym/model_builder.py
@@ -1,22 +1,15 @@
import time
import warnings
from typing import Any, Dict, Tuple

import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.imports import LightningModule
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.models.gnn import GNN
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.register import network_dict, register_network

try:
from pytorch_lightning import LightningModule
except ImportError:
LightningModule = torch.nn.Module
warnings.warn("Please install 'pytorch_lightning' for using the GraphGym "
"experiment manager via 'pip install pytorch_lightning'")

register_network('gnn', GNN)


@@ -69,7 +62,7 @@ def pre_mp(self) -> torch.nn.Module:
return self.model.pre_mp


def create_model(to_device=True, dim_in=None, dim_out=None):
def create_model(to_device=True, dim_in=None, dim_out=None) -> GraphGymModule:
r"""Create model for graph machine learning.
Args:
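
With LightningModule now coming from the shared imports module, GraphGymModule owns its optimization setup, so the standalone create_optimizer/create_scheduler calls disappear from the training script. A sketch of the resulting contract, mirroring the updated test above (the Adam/CosineAnnealingLR pairing reflects the default cfg.optim):

    model = create_model()  # GraphGymModule, ready to pass to a pl.Trainer
    optimizers, schedulers = model.configure_optimizers()  # ([optimizer], [scheduler])
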
139 changes: 47 additions & 92 deletions torch_geometric/graphgym/train.py
@@ -1,95 +1,50 @@
import logging
import time
from typing import Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.graphgym.checkpoint import (
clean_ckpt,
load_ckpt,
save_ckpt,
)
from torch_geometric.data.lightning_datamodule import LightningDataModule
from torch_geometric.graphgym import create_loader
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.utils.epoch import (
is_ckpt_epoch,
is_eval_epoch,
is_train_eval_epoch,
)


def train_epoch(logger, loader, model, optimizer, scheduler):
model.train()
time_start = time.time()
for batch in loader:
batch.split = 'train'
optimizer.zero_grad()
batch.to(torch.device(cfg.device))
pred, true = model(batch)
loss, pred_score = compute_loss(pred, true)
loss.backward()
optimizer.step()
logger.update_stats(true=true.detach().cpu(),
pred=pred_score.detach().cpu(), loss=loss.item(),
lr=scheduler.get_last_lr()[0],
time_used=time.time() - time_start,
params=cfg.params)
time_start = time.time()
scheduler.step()


@torch.no_grad()
def eval_epoch(logger, loader, model, split='val'):
model.eval()
time_start = time.time()
for batch in loader:
batch.split = split
batch.to(torch.device(cfg.device))
pred, true = model(batch)
loss, pred_score = compute_loss(pred, true)
logger.update_stats(true=true.detach().cpu(),
pred=pred_score.detach().cpu(), loss=loss.item(),
lr=0, time_used=time.time() - time_start,
params=cfg.params)
time_start = time.time()


def train(loggers, loaders, model, optimizer, scheduler):
"""
The core training pipeline
Args:
loggers: List of loggers
loaders: List of loaders
model: GNN model
optimizer: PyTorch optimizer
scheduler: PyTorch learning rate scheduler
"""
start_epoch = 0
if cfg.train.auto_resume:
start_epoch = load_ckpt(model, optimizer, scheduler,
cfg.train.epoch_resume)
if start_epoch == cfg.optim.max_epoch:
logging.info('Checkpoint found, Task already done')
else:
logging.info('Start from epoch {}'.format(start_epoch))

num_splits = len(loggers)
split_names = ['val', 'test']
for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
if is_train_eval_epoch(cur_epoch):
loggers[0].write_epoch(cur_epoch)
if is_eval_epoch(cur_epoch):
for i in range(1, num_splits):
eval_epoch(loggers[i], loaders[i], model,
split=split_names[i - 1])
loggers[i].write_epoch(cur_epoch)
if is_ckpt_epoch(cur_epoch) and cfg.train.enable_ckpt:
save_ckpt(model, optimizer, scheduler, cur_epoch)
for logger in loggers:
logger.close()
if cfg.train.ckpt_clean:
clean_ckpt()

logging.info('Task done, results saved in {}'.format(cfg.run_dir))
from torch_geometric.graphgym.imports import pl
from torch_geometric.graphgym.logger import LoggerCallback
from torch_geometric.graphgym.model_builder import GraphGymModule


class GraphGymDataModule(LightningDataModule):
def __init__(self):
self.loaders = create_loader()
super().__init__(has_val=True, has_test=True)

def train_dataloader(self) -> DataLoader:
return self.loaders[0]

def val_dataloader(self) -> DataLoader:
# better way would be to test after fit.
# First call trainer.fit(...) then trainer.test(...)
return self.loaders[1]

def test_dataloader(self) -> DataLoader:
return self.loaders[2]


def train(model: GraphGymModule, datamodule, logger: bool = True,
trainer_config: Optional[dict] = None):
callbacks = []
if logger:
callbacks.append(LoggerCallback())
if cfg.train.enable_ckpt:
ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())
callbacks.append(ckpt_cbk)

trainer_config = trainer_config or {}
trainer = pl.Trainer(
**trainer_config,
enable_checkpointing=cfg.train.enable_ckpt,
callbacks=callbacks,
default_root_dir=cfg.out_dir,
max_epochs=cfg.optim.max_epoch,
)

trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
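
Checkpointing, the output directory, and the epoch budget stay under cfg control, while any other pl.Trainer argument passes through trainer_config. For example, the updated test disables the progress bar this way:

    train(model, datamodule, logger=True,
          trainer_config={"enable_progress_bar": False})
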
6 changes: 3 additions & 3 deletions torch_geometric/graphgym/utils/agg_runs.py
@@ -28,7 +28,7 @@ def is_seed(s):


def is_split(s):
if s in ['train', 'val', 'test']:
if s in ['train', 'val']:
return True
else:
return False
@@ -86,8 +86,8 @@ def agg_runs(dir, metric_best='auto'):
validation performance. Options: auto, accuracy, auc.
'''
results = {'train': None, 'val': None, 'test': None}
results_best = {'train': None, 'val': None, 'test': None}
results = {'train': None, 'val': None}
results_best = {'train': None, 'val': None}
for seed in os.listdir(dir):
if is_seed(seed):
dir_seed = os.path.join(dir, seed)
