Implement LightningModule & LoggerCallback in GraphGym (#4531)
* implement model steps and configure_optimizer

* fix imports

* dummy pr: logger pl callback (#3)

implement logger pl callback

* update logger

* add test

* fix test

* fix

* fixes

* add test

* update test

* test configure optimizer

* test configure optimizer

* commit suggested change

Co-authored-by: Jirka Borovec <[email protected]>

* remove redundant parameters

* add typing

* apply pr suggestions

* apply suggestions

* apply suggestions

* apply suggestions

* test logger

* remove graphgym from minimal installation

* fix

* fix minimal test

* changelog

* update

* update

* update

* lint

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
3 people authored May 13, 2022
1 parent 6fd6f5b commit c55729b
Showing 5 changed files with 262 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added PyTorch Lightning support in GraphGym ([#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531))
- Added support for returning embeddings in `MLP` models ([#4625](https://github.com/pyg-team/pytorch_geometric/pull/4625))
- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620))
- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
93 changes: 92 additions & 1 deletion test/graphgym/test_graphgym.py
@@ -18,7 +18,11 @@
set_run_dir,
)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import create_logger, set_printing
from torch_geometric.graphgym.logger import (
LoggerCallback,
create_logger,
set_printing,
)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
@@ -33,6 +37,10 @@

num_trivial_metric_calls = 0

Args = namedtuple('Args', ['cfg_file', 'opts'])
root = osp.join(osp.dirname(osp.realpath(__file__)))
args = Args(osp.join(root, 'example_node.yml'), [])


def trivial_metric(true, pred, task_type):
global num_trivial_metric_calls
@@ -110,3 +118,86 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
agg_runs(cfg.out_dir, cfg.metric_best)

shutil.rmtree(cfg.out_dir)


@withPackage('yacs')
@withPackage('pytorch_lightning')
def test_graphgym_module(tmpdir):
import pytorch_lightning as pl

load_cfg(cfg, args)
cfg.out_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
cfg.run_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
cfg.dataset.dir = osp.join(tmpdir, 'pyg_test_datasets', 'Planetoid')

set_out_dir(cfg.out_dir, args.cfg_file)
dump_cfg(cfg)
set_printing()

seed_everything(cfg.seed)
auto_select_device()
set_run_dir(cfg.out_dir)

loaders = create_loader()
assert len(loaders) == 3

model = create_model()
assert isinstance(model, pl.LightningModule)

optimizer, scheduler = model.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adam)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

cfg.params = params_count(model)
assert cfg.params == 23880

keys = {"loss", "true", "pred_score", "step_end_time"}
# test training step
batch = next(iter(loaders[0]))
outputs = model.training_step(batch)
assert keys == set(outputs.keys())
assert isinstance(outputs["loss"], torch.Tensor)

# test validation step
batch = next(iter(loaders[1]))
outputs = model.validation_step(batch)
assert keys == set(outputs.keys())
assert isinstance(outputs["loss"], torch.Tensor)

# test test step
batch = next(iter(loaders[2]))
outputs = model.test_step(batch)
assert keys == set(outputs.keys())
assert isinstance(outputs["loss"], torch.Tensor)

shutil.rmtree(cfg.out_dir)


@withPackage('yacs')
@withPackage('pytorch_lightning')
def test_train(tmpdir):
import pytorch_lightning as pl

load_cfg(cfg, args)
cfg.out_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
cfg.run_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
cfg.dataset.dir = osp.join(tmpdir, 'pyg_test_datasets', 'Planetoid')

set_out_dir(cfg.out_dir, args.cfg_file)
dump_cfg(cfg)
set_printing()

seed_everything(cfg.seed)
auto_select_device()
set_run_dir(cfg.out_dir)

loaders = create_loader()
model = create_model()
cfg.params = params_count(model)
logger = LoggerCallback()
trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=logger,
log_every_n_steps=1)
train_loader, val_loader = loaders[0], loaders[1]
trainer.fit(model, train_loader, val_loader)

shutil.rmtree(cfg.out_dir)
11 changes: 11 additions & 0 deletions test/graphgym/test_logger.py
@@ -0,0 +1,11 @@
from torch_geometric.graphgym.logger import Logger, LoggerCallback
from torch_geometric.testing import withPackage


@withPackage('yacs')
@withPackage('pytorch_lightning')
def test_logger_callback():
logger = LoggerCallback()
assert isinstance(logger.train_logger, Logger)
assert isinstance(logger.val_logger, Logger)
assert isinstance(logger.test_logger, Logger)
136 changes: 130 additions & 6 deletions torch_geometric/graphgym/logger.py
@@ -1,6 +1,9 @@
import logging
import math
import sys
import time
import warnings
from typing import Any, Dict, Optional

import torch

@@ -10,6 +13,16 @@
from torch_geometric.graphgym.utils.device import get_current_gpu_usage
from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb

try:
import pytorch_lightning as pl
from pytorch_lightning import Callback

except ImportError:
pl = None
Callback = object
warnings.warn("Please install 'pytorch_lightning' for using the GraphGym "
"experiment manager via 'pip install pytorch_lightning'")


def set_printing():
"""
@@ -236,14 +249,125 @@ def infer_task():


def create_logger():
"""
Create logger for the experiment
Returns: List of logger objects
"""
r"""Create logger for the experiment."""
loggers = []
names = ['train', 'val', 'test']
for i in range(cfg.share.num_splits):
loggers.append(Logger(name=names[i], task_type=infer_task()))
return loggers


class LoggerCallback(Callback):
def __init__(self):
self._logger = create_logger()
self._train_epoch_start_time = None
self._val_epoch_start_time = None
self._test_epoch_start_time = None

@property
def train_logger(self) -> Any:
return self._logger[0]

@property
def val_logger(self) -> Any:
return self._logger[1]

@property
def test_logger(self) -> Any:
return self._logger[2]

def _get_stats(
self,
epoch_start_time: float,
outputs: Dict[str, Any],
trainer: 'pl.Trainer',
) -> Dict:
return dict(
true=outputs['true'].detach().cpu(),
pred=outputs['pred_score'].detach().cpu(),
loss=float(outputs['loss']),
lr=trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0],
time_used=time.time() - epoch_start_time,
params=cfg.params,
)

def on_train_epoch_start(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self._train_epoch_start_time = time.time()

def on_validation_epoch_start(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self._val_epoch_start_time = time.time()

def on_test_epoch_start(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self._test_epoch_start_time = time.time()

def on_train_batch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
outputs: Dict[str, Any],
batch: Any,
batch_idx: int,
unused: int = 0,
) -> None:
stats = self._get_stats(self._train_epoch_start_time, outputs, trainer)
self.train_logger.update_stats(**stats)

def on_validation_batch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
outputs: Optional[Dict[str, Any]],
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
stats = self._get_stats(self._val_epoch_start_time, outputs, trainer)
self.val_logger.update_stats(**stats)

def on_test_batch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
outputs: Optional[Dict[str, Any]],
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
stats = self._get_stats(self._test_epoch_start_time, outputs, trainer)
self.test_logger.update_stats(**stats)

def on_train_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self.train_logger.write_epoch(trainer.current_epoch)
self.train_logger.close()

def on_validation_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self.val_logger.write_epoch(trainer.current_epoch)
self.val_logger.close()

def on_test_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
self.test_logger.write_epoch(trainer.current_epoch)
self.test_logger.close()
30 changes: 28 additions & 2 deletions torch_geometric/graphgym/model_builder.py
@@ -1,9 +1,13 @@
import time
import warnings
from typing import Any, Dict, Tuple

import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.models.gnn import GNN
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.register import network_dict, register_network

try:
@@ -19,12 +23,35 @@
class GraphGymModule(LightningModule):
def __init__(self, dim_in, dim_out, cfg):
super().__init__()
self.cfg = cfg
self.model = network_dict[cfg.model.type](dim_in=dim_in,
dim_out=dim_out)

def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def configure_optimizers(self) -> Tuple[Any, Any]:
optimizer = create_optimizer(self.model.parameters(), self.cfg.optim)
scheduler = create_scheduler(optimizer, self.cfg.optim)
return [optimizer], [scheduler]

def _shared_step(self, batch, split: str) -> Dict:
batch.split = split
pred, true = self(batch)
loss, pred_score = compute_loss(pred, true)
step_end_time = time.time()
return dict(loss=loss, true=true, pred_score=pred_score,
step_end_time=step_end_time)

def training_step(self, batch, *args, **kwargs):
return self._shared_step(batch, split="train")

def validation_step(self, batch, *args, **kwargs):
return self._shared_step(batch, split="val")

def test_step(self, batch, *args, **kwargs):
return self._shared_step(batch, split="test")

@property
def encoder(self) -> torch.nn.Module:
return self.model.encoder
@@ -43,8 +70,7 @@ def pre_mp(self) -> torch.nn.Module:


def create_model(to_device=True, dim_in=None, dim_out=None):
r"""
Create model for graph machine learning
r"""Create model for graph machine learning.
Args:
to_device (bool): Whether to transfer the model to the current device
