-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GAT
: Weights & Biases Tracking (#4672)
* GAT wandb example * changelog * explainer
- Loading branch information
Showing
4 changed files
with
58 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,79 @@ | ||
import argparse | ||
import os.path as osp | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.datasets import Planetoid | ||
from torch_geometric.logging import init_wandb, log | ||
from torch_geometric.nn import GATConv | ||
|
||
dataset = 'Cora' | ||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset) | ||
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) | ||
data = dataset[0] | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--dataset', type=str, default='Cora') | ||
parser.add_argument('--hidden_channels', type=int, default=8) | ||
parser.add_argument('--heads', type=int, default=8) | ||
parser.add_argument('--lr', type=float, default=0.005) | ||
parser.add_argument('--epochs', type=int, default=200) | ||
parser.add_argument('--wandb', action='store_true', help='Track experiment') | ||
args = parser.parse_args() | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs, | ||
hidden_channels=args.hidden_channels, lr=args.lr, device=device) | ||
|
||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') | ||
dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures()) | ||
data = dataset[0].to(device) | ||
|
||
class Net(torch.nn.Module): | ||
def __init__(self, in_channels, out_channels): | ||
super().__init__() | ||
|
||
self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6) | ||
# On the Pubmed dataset, use heads=8 in conv2. | ||
self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False, | ||
dropout=0.6) | ||
class GAT(torch.nn.Module): | ||
def __init__(self, in_channels, hidden_channels, out_channels, heads): | ||
super().__init__() | ||
self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6) | ||
# On the Pubmed dataset, use `heads` output heads in `conv2`. | ||
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, | ||
concat=False, dropout=0.6) | ||
|
||
def forward(self, x, edge_index): | ||
x = F.dropout(x, p=0.6, training=self.training) | ||
x = F.elu(self.conv1(x, edge_index)) | ||
x = F.dropout(x, p=0.6, training=self.training) | ||
x = self.conv2(x, edge_index) | ||
return F.log_softmax(x, dim=-1) | ||
return x | ||
|
||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
model = Net(dataset.num_features, dataset.num_classes).to(device) | ||
data = data.to(device) | ||
model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes, | ||
args.heads).to(device) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) | ||
|
||
|
||
def train(data): | ||
def train(): | ||
model.train() | ||
optimizer.zero_grad() | ||
out = model(data.x, data.edge_index) | ||
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) | ||
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) | ||
loss.backward() | ||
optimizer.step() | ||
return float(loss) | ||
|
||
|
||
@torch.no_grad() | ||
def test(data): | ||
def test(): | ||
model.eval() | ||
out, accs = model(data.x, data.edge_index), [] | ||
for _, mask in data('train_mask', 'val_mask', 'test_mask'): | ||
acc = float((out[mask].argmax(-1) == data.y[mask]).sum() / mask.sum()) | ||
accs.append(acc) | ||
pred = model(data.x, data.edge_index).argmax(dim=-1) | ||
|
||
accs = [] | ||
for mask in [data.train_mask, data.val_mask, data.test_mask]: | ||
accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) | ||
return accs | ||
|
||
|
||
for epoch in range(1, 201): | ||
train(data) | ||
train_acc, val_acc, test_acc = test(data) | ||
print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' | ||
f'Test: {test_acc:.4f}') | ||
best_val_acc = final_test_acc = 0 | ||
for epoch in range(1, args.epochs + 1): | ||
loss = train() | ||
train_acc, val_acc, tmp_test_acc = test() | ||
if val_acc > best_val_acc: | ||
best_val_acc = val_acc | ||
test_acc = tmp_test_acc | ||
log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters