Lightning Trainer integration in GraphGym (#4689)
* add trainer
* add datamodule
* fix tests
* fix tests
* fixes
* soft import
* revert
* update
* fix tests
* reformat
* remove test split
* remove skip_train_eval
* add soft import
* update
* update
* update
* flake

Co-authored-by: rusty1s <[email protected]>
1 parent 70a3760 · commit 5a4f868
Showing 9 changed files with 89 additions and 147 deletions.
@@ -0,0 +1,15 @@
import warnings

import torch

try:
    import pytorch_lightning as pl
    from pytorch_lightning import Callback, LightningModule
except ImportError:
    # define fallbacks
    pl = object
    LightningModule = torch.nn.Module
    Callback = object

    warnings.warn("Please install 'pytorch_lightning' for using the GraphGym "
                  "experiment manager via 'pip install pytorch_lightning'")
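The soft-import hunk above gives the rest of GraphGym a single place to pull Lightning symbols from, with harmless stand-ins when the dependency is missing. A minimal consumer sketch, assuming only what the hunk defines (the `pl is not object` check and the class names are illustrative and not part of this commit):

# Illustrative consumer of the soft import above (not part of this commit):
from torch_geometric.graphgym.imports import Callback, LightningModule, pl


class MyCallback(Callback):  # plain `object` base class if Lightning is absent
    pass


class MyModel(LightningModule):  # falls back to `torch.nn.Module`
    pass


# `pl` is the real pytorch_lightning module only when the import succeeded:
if pl is not object:
    print('pytorch_lightning', pl.__version__, 'is available')
else:
    print('pytorch_lightning not installed; the Lightning trainer is unavailable')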
@@ -1,95 +1,50 @@
import logging
import time
from typing import Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.graphgym.checkpoint import (
    clean_ckpt,
    load_ckpt,
    save_ckpt,
)
from torch_geometric.data.lightning_datamodule import LightningDataModule
from torch_geometric.graphgym import create_loader
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.utils.epoch import (
    is_ckpt_epoch,
    is_eval_epoch,
    is_train_eval_epoch,
)


def train_epoch(logger, loader, model, optimizer, scheduler):
    model.train()
    time_start = time.time()
    for batch in loader:
        batch.split = 'train'
        optimizer.zero_grad()
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        loss.backward()
        optimizer.step()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(), loss=loss.item(),
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()
    scheduler.step()


@torch.no_grad()
def eval_epoch(logger, loader, model, split='val'):
    model.eval()
    time_start = time.time()
    for batch in loader:
        batch.split = split
        batch.to(torch.device(cfg.device))
        pred, true = model(batch)
        loss, pred_score = compute_loss(pred, true)
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred_score.detach().cpu(), loss=loss.item(),
                            lr=0, time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()


def train(loggers, loaders, model, optimizer, scheduler):
    """
    The core training pipeline

    Args:
        loggers: List of loggers
        loaders: List of loaders
        model: GNN model
        optimizer: PyTorch optimizer
        scheduler: PyTorch learning rate scheduler
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch {}'.format(start_epoch))

    num_splits = len(loggers)
    split_names = ['val', 'test']
    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
        if is_train_eval_epoch(cur_epoch):
            loggers[0].write_epoch(cur_epoch)
        if is_eval_epoch(cur_epoch):
            for i in range(1, num_splits):
                eval_epoch(loggers[i], loaders[i], model,
                           split=split_names[i - 1])
                loggers[i].write_epoch(cur_epoch)
        if is_ckpt_epoch(cur_epoch) and cfg.train.enable_ckpt:
            save_ckpt(model, optimizer, scheduler, cur_epoch)
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info('Task done, results saved in {}'.format(cfg.run_dir))
from torch_geometric.graphgym.imports import pl
from torch_geometric.graphgym.logger import LoggerCallback
from torch_geometric.graphgym.model_builder import GraphGymModule


class GraphGymDataModule(LightningDataModule):
    def __init__(self):
        self.loaders = create_loader()
        super().__init__(has_val=True, has_test=True)

    def train_dataloader(self) -> DataLoader:
        return self.loaders[0]

    def val_dataloader(self) -> DataLoader:
        # better way would be to test after fit.
        # First call trainer.fit(...) then trainer.test(...)
        return self.loaders[1]

    def test_dataloader(self) -> DataLoader:
        return self.loaders[2]


def train(model: GraphGymModule, datamodule, logger: bool = True,
          trainer_config: Optional[dict] = None):
    callbacks = []
    if logger:
        callbacks.append(LoggerCallback())
    if cfg.train.enable_ckpt:
        ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())
        callbacks.append(ckpt_cbk)

    trainer_config = trainer_config or {}
    trainer = pl.Trainer(
        **trainer_config,
        enable_checkpointing=cfg.train.enable_ckpt,
        callbacks=callbacks,
        default_root_dir=cfg.out_dir,
        max_epochs=cfg.optim.max_epoch,
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)
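Taken together, this hunk swaps GraphGym's hand-rolled epoch loop for a PyTorch Lightning Trainer driven by GraphGymDataModule and the new train() entry point. A minimal usage sketch follows, assuming the GraphGym config has already been loaded into cfg, that create_model() from torch_geometric.graphgym.model_builder builds the GraphGymModule, and that the module paths match the imports shown in the diff; the trainer_config keys are ordinary pl.Trainer arguments chosen here for illustration, not values fixed by this commit.

# Assumed driver sketch, not part of this commit:
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule, train

datamodule = GraphGymDataModule()  # wraps the train/val/test loaders from create_loader()
model = create_model()             # GraphGymModule built from the global cfg
train(model, datamodule, logger=True,
      trainer_config={'accelerator': 'auto', 'devices': 1})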