Skip to content

Commit

Permalink
update train script
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-chen committed Jan 29, 2021
1 parent ff535b9 commit 6f8c627
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
41 changes: 41 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import tqdm
import torch


def performance(loader, net, device):
net.eval()
correct, total = 0, 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(tqdm.tqdm(loader)):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum().item()
acc = correct/total
return acc


if __name__ == "__main__":
import argparse
from ward import WARD
import torch.utils.data as Data

parser = argparse.ArgumentParser(description='Feature Graph Networks')
parser.add_argument("--load", type=str, required=True, help="load pretrained model file")
parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu")
parser.add_argument("--data-root", type=str, default='/data/datasets', help="dataset location to be download")
parser.add_argument("--duration", type=int, default=50, help="duration")
parser.add_argument("--batch-size", type=int, default=100, help="minibatch size")
args = parser.parse_args(); print(args)

test_data = WARD(root=args.data_root, duration=args.duration, train=False)
test_loader = Data.DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=True, drop_last=True)
train_data = WARD(root=args.data_root, duration=args.duration, train=True)
train_loader = Data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, drop_last=True)

net = torch.load(args.load)
train_acc = performance(train_loader, net, args.device)
test_acc = performance(test_loader, net, args.device)
print("Evaluating model: ", args.load)
print("Train Acc: %f; Test Acc: %f"%(train_acc, test_acc))
16 changes: 10 additions & 6 deletions lifelong.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gat import GAT
from fgn import FGN
from ward import WARD
from nonlifelong import performance
from evaluation import performance
from torch_util import count_parameters


Expand Down Expand Up @@ -97,8 +97,7 @@ def sample(self, inputs, targets):
parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu")
parser.add_argument("--data-root", type=str, default='/data/datasets', help="dataset location to be download")
parser.add_argument("--model", type=str, default='FGN', help="FGN or GAT")
parser.add_argument("--load", type=str, default=None, help="load pretrained model file")
parser.add_argument("--save", type=str, default='saves/test', help="model file to save")
parser.add_argument("--save", type=str, default='saves', help="location to save model")
parser.add_argument("--optim", type=str, default='SGD', help="SGD or Adam")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--duration", type=int, default=50, help="duration")
Expand All @@ -109,7 +108,7 @@ def sample(self, inputs, targets):
parser.add_argument("--seed", type=int, default=0, help='Random seed.')
args = parser.parse_args(); print(args)
os.makedirs(args.data_root, exist_ok=True)
os.makedirs('saves', exist_ok=True)
os.makedirs(args.save, exist_ok=True)
torch.manual_seed(args.seed)
Nets = {'fgn':FGN, 'gat':GAT}
Net = Nets[args.model.lower()]
Expand All @@ -129,5 +128,10 @@ def sample(self, inputs, targets):
test_acc = performance(test_loader, lgl.net, args.device)
print('Test Acc: %f'%(test_acc))
if args.save is not None:
print('Saving model to', args.save+'-%d.model'%(batch_idx))
torch.save(lgl.net, args.save+'-%d.model'%(batch_idx))
filename = args.save+'/lifelong-%s-s%d-it%d.model'%(args.model, args.seed, batch_idx+1)
print('Saving model to', filename)
torch.save(lgl.net, filename)

test_acc = performance(test_loader, lgl.net, args.device)
print('Final Test Acc: %f'%(test_acc))
torch.save(lgl.net, args.save+'/lifelong-%s-s%d.model'%(args.model, args.seed))
25 changes: 6 additions & 19 deletions nonlifelong.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,11 @@
from gat import GAT
from fgn import FGN
from ward import WARD
from evaluation import performance
from torch_util import count_parameters
from torch_util import EarlyStopScheduler


def performance(loader, net, device):
net.eval()
correct, total = 0, 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(tqdm.tqdm(loader)):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum().item()
acc = correct/total
return acc


def train(loader, net, device):
net.train()
correct, total = 0, 0
Expand All @@ -54,8 +41,7 @@ def train(loader, net, device):
parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu")
parser.add_argument("--data-root", type=str, default='/data/datasets', help="dataset location to be download")
parser.add_argument("--model", type=str, default='FGN', help="FGN or GAT")
parser.add_argument("--load", type=str, default=None, help="load pretrained model file")
parser.add_argument("--save", type=str, default='saves/test', help="model file to save")
parser.add_argument("--save", type=str, default='saves', help="location to save model")
parser.add_argument("--optim", type=str, default='SGD', help="SGD or Adam")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--epoch", type=int, default=50, help="epoch")
Expand All @@ -64,7 +50,7 @@ def train(loader, net, device):
parser.add_argument("--seed", type=int, default=0, help='Random seed.')
args = parser.parse_args(); print(args)
os.makedirs(args.data_root, exist_ok=True)
os.makedirs('saves', exist_ok=True)
os.makedirs(args.save, exist_ok=True)
torch.manual_seed(args.seed)
Nets = {'fgn':FGN, 'gat':GAT}
Net = Nets[args.model.lower()]
Expand All @@ -88,8 +74,9 @@ def train(loader, net, device):

if best_acc < test_acc and args.save is not None:
best_acc = test_acc
print('Saving new best model to', args.save+'.model')
torch.save(net, args.save+'.model')
filename = args.save+'/nonlifelong-%s-s%d.model'%(args.model, args.seed)
print('Saving new best model to', filename)
torch.save(net, filename)

if scheduler.step(1-test_acc):
print('Early Stoping..')
Expand Down

0 comments on commit 6f8c627

Please sign in to comment.