GAT: Weights & Biases Tracking (#4672)
* GAT wandb example

* changelog

* explainer
rusty1s authored May 18, 2022
1 parent 660a747 commit 0e2f726
Showing 4 changed files with 58 additions and 32 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/examples.yml

@@ -41,8 +41,18 @@ jobs:
       run: |
         pip install .[benchmark]

-    - name: Run examples
+    - name: Run GCN on Cora
       run: |
         python examples/gcn.py --wandb
       env:
         WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}

+    - name: Run GAT on Cora
+      run: |
+        python examples/gat.py --wandb
+      env:
+        WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}

     - name: Run GNNExplainer
       run: |
         python examples/gnn_explainer.py
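
The new CI step simply exports the `WANDB_API_KEY` repository secret and invokes the example with `--wandb`. The sketch below shows a hypothetical local equivalent of that step; it is an illustration, not part of this commit, and the placeholder API key value is an assumption.

# Hypothetical local equivalent of the "Run GAT on Cora" CI step (a sketch,
# not part of the commit). The API key value below is a placeholder.
import os
import subprocess

env = dict(os.environ, WANDB_API_KEY='<your-wandb-api-key>')
subprocess.run(['python', 'examples/gat.py', '--wandb'], env=env, check=True)
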
2 changes: 1 addition & 1 deletion CHANGELOG.md

@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
-- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656))
+- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672))
 - Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628))
 - Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
 - Added the `MLP.plain_last=False` option ([4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
71 changes: 44 additions & 27 deletions examples/gat.py

@@ -1,62 +1,79 @@
+import argparse
 import os.path as osp

 import torch
 import torch.nn.functional as F

 import torch_geometric.transforms as T
 from torch_geometric.datasets import Planetoid
+from torch_geometric.logging import init_wandb, log
 from torch_geometric.nn import GATConv

-dataset = 'Cora'
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
-dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
-data = dataset[0]
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', type=str, default='Cora')
+parser.add_argument('--hidden_channels', type=int, default=8)
+parser.add_argument('--heads', type=int, default=8)
+parser.add_argument('--lr', type=float, default=0.005)
+parser.add_argument('--epochs', type=int, default=200)
+parser.add_argument('--wandb', action='store_true', help='Track experiment')
+args = parser.parse_args()
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,
+           hidden_channels=args.hidden_channels, lr=args.lr, device=device)
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
+dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())
+data = dataset[0].to(device)


-class Net(torch.nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
-        # On the Pubmed dataset, use heads=8 in conv2.
-        self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False,
-                             dropout=0.6)
+class GAT(torch.nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, heads):
+        super().__init__()
+        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
+        # On the Pubmed dataset, use `heads` output heads in `conv2`.
+        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1,
+                             concat=False, dropout=0.6)

     def forward(self, x, edge_index):
         x = F.dropout(x, p=0.6, training=self.training)
         x = F.elu(self.conv1(x, edge_index))
         x = F.dropout(x, p=0.6, training=self.training)
         x = self.conv2(x, edge_index)
-        return F.log_softmax(x, dim=-1)
+        return x
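
As an aside (not part of the diff): the first `GATConv` concatenates its `heads` attention heads, so its output dimensionality is `hidden_channels * heads`, which is exactly what the second layer expects as input; with `concat=False` the final layer returns a single averaged head. A quick shape check with made-up toy sizes:

# Illustrative shape check for the two-layer GAT above (toy graph, toy sizes).
import torch
from torch_geometric.nn import GATConv

x = torch.randn(4, 16)                                   # 4 nodes, 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])  # toy edges
conv1 = GATConv(16, 8, heads=8, dropout=0.6)             # concat=True by default
conv2 = GATConv(8 * 8, 7, heads=1, concat=False, dropout=0.6)
h = conv1(x, edge_index)      # shape [4, 8 * 8]: heads are concatenated
out = conv2(h, edge_index)    # shape [4, 7]: single (averaged) output head
print(h.shape, out.shape)
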


-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = Net(dataset.num_features, dataset.num_classes).to(device)
-data = data.to(device)
+model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes,
+            args.heads).to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


-def train(data):
+def train():
     model.train()
     optimizer.zero_grad()
     out = model(data.x, data.edge_index)
-    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
+    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
     loss.backward()
     optimizer.step()
     return float(loss)


 @torch.no_grad()
-def test(data):
+def test():
     model.eval()
-    out, accs = model(data.x, data.edge_index), []
-    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
-        acc = float((out[mask].argmax(-1) == data.y[mask]).sum() / mask.sum())
-        accs.append(acc)
+    pred = model(data.x, data.edge_index).argmax(dim=-1)
+
+    accs = []
+    for mask in [data.train_mask, data.val_mask, data.test_mask]:
+        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
     return accs


-for epoch in range(1, 201):
-    train(data)
-    train_acc, val_acc, test_acc = test(data)
-    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
-          f'Test: {test_acc:.4f}')
+best_val_acc = final_test_acc = 0
+for epoch in range(1, args.epochs + 1):
+    loss = train()
+    train_acc, val_acc, tmp_test_acc = test()
+    if val_acc > best_val_acc:
+        best_val_acc = val_acc
+        test_acc = tmp_test_acc
+    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)
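
All reporting in the new script goes through `torch_geometric.logging.init_wandb` and `log`: running `python examples/gat.py` presumably keeps plain console output, and adding `--wandb` streams the same metrics to Weights & Biases (this is how the CI step above invokes it). As a rough mental model only, not the actual `torch_geometric.logging` implementation, the two helpers behave like the wrapper sketched below; only `wandb.init` and `wandb.log` are the real wandb API, and the `sys.argv` gating and `_wandb_enabled` flag are assumptions for illustration.

# Hedged sketch of what init_wandb/log amount to conceptually. The real helpers
# live in torch_geometric.logging and may work differently.
import sys
import wandb

_wandb_enabled = False

def init_wandb(name, **config):
    # Assume tracking is only switched on when the script was run with --wandb;
    # how the real helper decides this is not shown in the diff.
    global _wandb_enabled
    if '--wandb' not in sys.argv:
        return
    _wandb_enabled = True
    wandb.init(name=name, config=config)

def log(**metrics):
    # Always print to the console; additionally forward to wandb when enabled.
    print(', '.join(f'{k}: {v}' for k, v in metrics.items()))
    if _wandb_enabled:
        wandb.log(metrics)
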
5 changes: 2 additions & 3 deletions examples/gcn.py

@@ -75,12 +75,11 @@ def train():
 @torch.no_grad()
 def test():
     model.eval()
-    out = model(data.x, data.edge_index, data.edge_weight)
+    pred = model(data.x, data.edge_index, data.edge_weight).argmax(dim=-1)

     accs = []
     for mask in [data.train_mask, data.val_mask, data.test_mask]:
-        pred = out[mask].argmax(dim=-1)
-        accs.append(int((pred == data.y[mask]).sum()) / int(mask.sum()))
+        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
     return accs
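
The `gcn.py` change mirrors the new `test()` in `gat.py`: the argmax is taken once over all nodes and each mask then indexes into the resulting prediction vector. This is behavior-preserving because masking rows commutes with a row-wise argmax, as the small check below illustrates (random toy tensors, not part of the commit):

# Illustrative sanity check: argmax-then-mask equals mask-then-argmax.
import torch

out = torch.randn(6, 3)                                   # 6 nodes, 3 classes
mask = torch.tensor([True, False, True, True, False, True])
assert torch.equal(out.argmax(dim=-1)[mask], out[mask].argmax(dim=-1))
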
