Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement LightningModule & LoggerCallback in GraphGym #4531

Merged
merged 43 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
db71151
implement model steps and configure_optimizer
aniketmaurya Apr 21, 2022
7c4fbf0
fix imports
aniketmaurya Apr 21, 2022
39305a5
Merge branch 'pyg-team:master' into implement_module
aniketmaurya Apr 22, 2022
985937c
dummy pr: logger pl callback (#3)
aniketmaurya Apr 22, 2022
e4a1b62
update logger
aniketmaurya Apr 22, 2022
38d4100
add test
aniketmaurya Apr 22, 2022
bfff998
fix test
aniketmaurya Apr 22, 2022
48591cf
fix
aniketmaurya Apr 23, 2022
d297f8c
fixes
aniketmaurya Apr 24, 2022
e6d5dd1
Merge branch 'pyg-team:master' into implement_module
aniketmaurya Apr 24, 2022
fb87856
Merge branch 'pyg-team:master' into implement_module
aniketmaurya Apr 25, 2022
651eb84
add test
aniketmaurya Apr 25, 2022
e869a45
update test
aniketmaurya Apr 25, 2022
b0184fa
test configure optimizer
aniketmaurya Apr 25, 2022
2eda809
test configure optimizer
aniketmaurya Apr 25, 2022
dbbb338
Merge branch 'pyg-team:master' into implement_module
aniketmaurya Apr 25, 2022
1df4062
commit suggested change
aniketmaurya Apr 25, 2022
e7a2f46
remove redundant parameters
aniketmaurya Apr 25, 2022
3985e3a
Merge branch 'master' into implement_module
aniketmaurya Apr 28, 2022
64266c1
Merge branch 'master' into implement_module
aniketmaurya May 2, 2022
f16285b
add typing
aniketmaurya May 2, 2022
0466478
apply pr suggestions
aniketmaurya May 2, 2022
e36c481
Merge branch 'master' into implement_module
aniketmaurya May 5, 2022
c0b8006
Merge branch 'pyg-team:master' into implement_module
aniketmaurya May 8, 2022
1b2b1d9
Merge branch 'master' into implement_module
aniketmaurya May 10, 2022
23d2a13
Merge branch 'master' into implement_module
rusty1s May 10, 2022
99c9861
Merge branch 'master' into implement_module
aniketmaurya May 11, 2022
d00d4c4
apply suggestions
aniketmaurya May 11, 2022
4ef4828
Merge branch 'implement_module' of github.com:aniketmaurya/pytorch_ge…
aniketmaurya May 11, 2022
35cd7f5
apply suggestions
aniketmaurya May 11, 2022
e049293
apply suggestions
aniketmaurya May 11, 2022
8fb6761
test logger
aniketmaurya May 11, 2022
ebd5df3
remove graphgym from minimal installation
aniketmaurya May 11, 2022
9de0f1f
fix
aniketmaurya May 11, 2022
e7501c9
Merge branch 'master' into implement_module
aniketmaurya May 11, 2022
054068d
fix minimal test
aniketmaurya May 11, 2022
f9345eb
Merge branch 'master' into implement_module
aniketmaurya May 12, 2022
a819091
Merge branch 'master' into implement_module
aniketmaurya May 13, 2022
490eaad
changelog
rusty1s May 13, 2022
41e207b
update
rusty1s May 13, 2022
8f49289
update
rusty1s May 13, 2022
f09f687
update
rusty1s May 13, 2022
27afa37
lint
rusty1s May 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions test/graphgym/test_graphgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@
)
from torch_geometric.testing import withPackage

# `pytorch_lightning` is an optional dependency; tests that need it are
# gated via `@withPackage('pytorch_lightning')` below.
try:
    import pytorch_lightning as pl
except ImportError:
    pl = None

# Mutated by `trivial_metric` to count how often the metric is invoked.
num_trivial_metric_calls = 0

# Minimal stand-in for GraphGym's parsed command-line arguments:
Args = namedtuple('Args', ['cfg_file', 'opts'])
root = osp.join(osp.dirname(osp.realpath(__file__)))
args = Args(osp.join(root, 'example_node.yml'), [])


def trivial_metric(true, pred, task_type):
global num_trivial_metric_calls
Expand Down Expand Up @@ -110,3 +118,54 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
agg_runs(cfg.out_dir, cfg.metric_best)

shutil.rmtree(cfg.out_dir)


@withPackage('yacs')
@withPackage('pytorch_lightning')
def test_graphgym_module(tmpdir):
    """End-to-end check of the GraphGym ``LightningModule``: config loading,
    loader/model creation, optimizer configuration, and the output contract
    of the training/validation/test steps."""
    load_cfg(cfg, args)
    cfg.out_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join(tmpdir, 'pyg_test_datasets', 'Planetoid')

    set_out_dir(cfg.out_dir, args.cfg_file)
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir)

    loaders = create_loader()
    assert len(loaders) == 3

    model = create_model()
    assert isinstance(model, pl.LightningModule)

    optimizer, scheduler = model.configure_optimizers()
    assert isinstance(optimizer[0], torch.optim.Adam)
    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    # Every step must return the full stats dict consumed by `LoggerCallback`:
    keys = {"loss", "true", "pred", "pred_score", "step_end_time"}

    # Test training step:
    batch = next(iter(loaders[0]))
    outputs = model.training_step(batch)
    assert keys.issubset(outputs.keys())
    assert torch.is_tensor(outputs["loss"])

    # Test validation step:
    batch = next(iter(loaders[1]))
    outputs = model.validation_step(batch)
    assert keys.issubset(outputs.keys())
    assert torch.is_tensor(outputs["loss"])

    # Test test step:
    batch = next(iter(loaders[2]))
    outputs = model.test_step(batch)
    assert keys.issubset(outputs.keys())
    assert torch.is_tensor(outputs["loss"])

    shutil.rmtree(cfg.out_dir)
11 changes: 11 additions & 0 deletions test/graphgym/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torch_geometric.graphgym.logger import Logger, LoggerCallback
from torch_geometric.testing import withPackage


@withPackage('yacs')
@withPackage('pytorch_lightning')
def test_logger_callback():
    # A fresh callback exposes one GraphGym `Logger` per dataset split:
    callback = LoggerCallback()
    for split_logger in (callback.train_logger, callback.val_logger,
                         callback.test_logger):
        assert isinstance(split_logger, Logger)
115 changes: 115 additions & 0 deletions torch_geometric/graphgym/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import math
import sys
import time
import warnings
from typing import Any, Dict, Optional

import torch

Expand All @@ -10,6 +13,16 @@
from torch_geometric.graphgym.utils.device import get_current_gpu_usage
from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb

# `pytorch_lightning` is optional: when it is missing, fall back to a plain
# `object` base class so that `LoggerCallback` below can still be defined,
# and warn the user once at import time.
try:
    import pytorch_lightning as pl
    from pytorch_lightning import Callback

except ImportError:
    pl = None
    Callback = object
    warnings.warn("Please install 'pytorch_lightning' for using the GraphGym "
                  "experiment manager via 'pip install pytorch_lightning'")


def set_printing():
"""
Expand Down Expand Up @@ -247,3 +260,105 @@ def create_logger():
for i, dataset in enumerate(range(cfg.share.num_splits)):
loggers.append(Logger(name=names[i], task_type=infer_task()))
return loggers


class LoggerCallback(Callback):
    """A PyTorch Lightning callback that forwards the outputs of every
    training/validation/test step to a split-specific GraphGym
    :class:`Logger`, and writes aggregated epoch statistics at the end of
    each epoch."""
    def __init__(self):
        # `create_logger()` returns one `Logger` per split, in the order
        # [train, val, test] (exposed via the properties below).
        self._logger = create_logger()
        self._train_epoch_start_time = None
        self._val_epoch_start_time = None
        self._test_epoch_start_time = None

    def _create_stats(self, time_start, outputs: Dict[str, Any],
                      trainer: "pl.Trainer") -> Dict:
        """Convert a single step's outputs into the keyword arguments
        expected by :meth:`Logger.update_stats`.

        Tensors are detached and moved to CPU so the logger does not keep
        GPU memory or autograd graphs alive.
        """
        true: torch.Tensor = outputs["true"]
        pred: torch.Tensor = outputs["pred"]
        pred_score: torch.Tensor = outputs["pred_score"]
        loss: torch.Tensor = outputs["loss"]
        # Current learning rate of the first configured scheduler:
        lr = trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0]
        stats = dict(true=true.detach().cpu(), pred=pred.detach().cpu(),
                     pred_score=pred_score.detach().cpu(), loss=loss.item(),
                     lr=lr, time_used=time.time() - time_start,
                     params=cfg.params)

        return stats

    def on_train_epoch_start(self, trainer: "pl.Trainer",
                             pl_module: "pl.LightningModule") -> None:
        # Record the wall-clock start so `time_used` can be derived per batch.
        self._train_epoch_start_time = time.time()

    def on_validation_epoch_start(self, trainer: "pl.Trainer",
                                  pl_module: "pl.LightningModule") -> None:
        self._val_epoch_start_time = time.time()

    def on_test_epoch_start(self, trainer: "pl.Trainer",
                            pl_module: "pl.LightningModule") -> None:
        self._test_epoch_start_time = time.time()

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Dict[str, Any],
        batch: Any,
        batch_idx: int,
        unused: int = 0,
    ) -> None:
        stats = self._create_stats(self._train_epoch_start_time, outputs,
                                   trainer)
        self.train_logger.update_stats(**stats)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[Dict[str, Any]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        stats = self._create_stats(self._val_epoch_start_time, outputs,
                                   trainer)
        self.val_logger.update_stats(**stats)

    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[Dict[str, Any]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        stats = self._create_stats(self._test_epoch_start_time, outputs,
                                   trainer)
        self.test_logger.update_stats(**stats)

    def on_train_epoch_end(self, trainer: "pl.Trainer",
                           pl_module: "pl.LightningModule") -> None:
        # Flush the accumulated per-batch stats for this epoch.
        self.train_logger.write_epoch(trainer.current_epoch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer",
                                pl_module: "pl.LightningModule") -> None:
        self.val_logger.write_epoch(trainer.current_epoch)

    def on_test_epoch_end(self, trainer: "pl.Trainer",
                          pl_module: "pl.LightningModule") -> None:
        self.test_logger.write_epoch(trainer.current_epoch)

    def on_epoch_end(self, trainer: "pl.Trainer",
                     pl_module: "pl.LightningModule") -> None:
        # NOTE(review): this closes every split logger at each epoch end —
        # confirm `Logger` re-opens its writers on the next `update_stats`.
        for logger in self._logger:
            logger.close()

    @property
    def train_logger(self):
        # Logger for the training split.
        return self._logger[0]

    @property
    def val_logger(self):
        # Logger for the validation split.
        return self._logger[1]

    @property
    def test_logger(self):
        # Logger for the test split.
        return self._logger[2]
33 changes: 32 additions & 1 deletion torch_geometric/graphgym/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import time
import typing
import warnings
from typing import Any, Dict, Tuple

import torch

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.models.gnn import GNN
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.register import network_dict, register_network

if typing.TYPE_CHECKING:
from yacs.config import CfgNode

try:
from pytorch_lightning import LightningModule
except ImportError:
Expand All @@ -17,14 +25,37 @@


class GraphGymModule(LightningModule):
def __init__(self, dim_in, dim_out, cfg):
def __init__(self, dim_in, dim_out, cfg: "CfgNode"):
super().__init__()
self.cfg = cfg
self.model = network_dict[cfg.model.type](dim_in=dim_in,
dim_out=dim_out)

def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def configure_optimizers(self) -> Tuple[Any, Any]:
optimizer = create_optimizer(self.model.parameters(), self.cfg.optim)
scheduler = create_scheduler(optimizer, self.cfg.optim)
return [optimizer], [scheduler]

def _shared_step(self, batch, split: str) -> Dict:
batch.split = split
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
pred, true = self.forward(batch)
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
loss, pred_score = compute_loss(pred, true)
step_end_time = time.time()
return dict(loss=loss, true=true, pred=pred, pred_score=pred_score,
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
step_end_time=step_end_time)

def training_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "train")

def validation_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "val")

def test_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "test")

@property
def encoder(self) -> torch.nn.Module:
return self.model.encoder
Expand Down