Skip to content

Commit

Permalink
Add method to train NNRank
Browse files Browse the repository at this point in the history
  • Loading branch information
raviagrwl420 committed Dec 5, 2018
1 parent e04b9d1 commit 94d90b3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
15 changes: 14 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Main file

import torch
import itertools

import numpy as np

from src.lop import lop
from src.nn_rank import NNRank

# lop('RandB/N-p40-01').solve_instance()

model = NNRank(100, 5, 1)

inputs = torch.randn(10000).reshape(100, 100)
targets = torch.cat((torch.zeros(50), torch.ones(50)), dim=0)

lop('RandB/N-p40-01').solve_instance()
model.train(inputs, targets, 100)
27 changes: 17 additions & 10 deletions src/nn_rank.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import itertools

import torch

import torch.nn as nn
import torch.optim as optim

Expand All @@ -22,28 +26,31 @@ def __init__(self, n_in, n_hidden, n_out, bias=True):
def forward(self, inputs):
return self.model(inputs)

def train(self, inputs, targets):
def train(self, inputs, targets, num_epoch):
# Initialize criterion as MarginRankingLoss
criterion = nn.MarginRankingLoss()

# Initialize optimizer
optimizer = optim.SGD(list(self.fc1.parameters()) + list(self.fc2.parameters()),
lr=0.001, momentum=0.9)
lr=0.0001, momentum=0.9, weight_decay=0.01)

# Select all the items with target class 1 i.e good variables
item1 = torch.index_select(targets==1)
# Select all the items with target class 0 i.e. bad variables
item2 = torch.index_select(targets==0)
# Use all combinations of item1 and item2 along with the corresponding targets
modified_targets = # TODO: Compute!!

for epoch in range(NUM_EPOCH):
for epoch in range(num_epoch):
# Zero the parameter gradients
optimizer.zero_grad()

# Forward + backward + optimize
outputs = self.forward(inputs)

# Select all the items with target class 1 i.e good variables
outputs1 = torch.index_select(outputs, 0, torch.tensor(targets==1, dtype=torch.long).nonzero().reshape(-1,))
# Select all the items with target class 0 i.e. bad variables
outputs0 = torch.index_select(outputs, 0, torch.tensor(targets==0, dtype=torch.long).nonzero().reshape(-1,))
# Use all combinations of item1 and item2 along with the corresponding targets
pairs = list(itertools.product(outputs1, outputs0))
item1 = torch.stack([pair[0] for pair in pairs])
item2 = torch.stack([pair[1] for pair in pairs])
modified_targets = torch.ones(len(pairs))

# item1 is the first item from the pair
# item2 is the second item from the pair
# target is either +1 or -1 depending on the ordering of item1 and item2
Expand Down

0 comments on commit 94d90b3

Please sign in to comment.