Skip to content

Commit

Permalink
added batch support to main.py, fixes #15
Browse files Browse the repository at this point in the history
  • Loading branch information
bodokaiser committed Mar 15, 2017
1 parent b84e437 commit 72e6397
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,10 @@ def update(images):

return update

def threshold(image):
value = np.mean(image) - 2*np.var(image)
def threshold(images):
value = images.mean() - 2*images.var()

mask = image > value
mask = torch.from_numpy(mask.astype(np.float32))

return Variable(mask)
return Variable(images.gt(value).float())

def main(args):
model = Basic()
Expand All @@ -93,7 +90,7 @@ def main(args):
os.path.join(args.datadir, f'{int(i):02d}_us.mnc'))
for i in args.train
])
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

mr, us = dataset[120]
fixed_inputs = Variable(mr).unsqueeze(0)
Expand All @@ -114,20 +111,19 @@ def main(args):
train_loss = 0

for mr, us in dataloader:
if us.sum() > 1:
mask = threshold(us.numpy())
mask = threshold(us)

inputs = Variable(mr)
targets = Variable(us)
results = model(inputs)
inputs = Variable(mr)
targets = Variable(us)
results = model(inputs)

optimizer.zero_grad()
loss = results[0].mul(mask).dist(targets[0].mul(mask), 2)
loss.div_(mask.sum().data[0])
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss = results.mul(mask).dist(targets.mul(mask), 2)
loss.div_(mask.sum().data[0])
loss.backward()
optimizer.step()

train_loss += loss.data[0]
train_loss += loss.data[0]

test_losses.append(test_loss)
train_losses.append(train_loss)
Expand All @@ -152,6 +148,7 @@ def main(args):
parser.add_argument('--train', nargs='+', default=['13'])
parser.add_argument('--epochs', type=int, nargs='?', default=20)
parser.add_argument('--datadir', type=str, nargs='?', default='mnibite')
parser.add_argument('--batch_size', type=int, nargs='?', default=64)
parser.add_argument('--show-loss', dest='show_loss', action='store_true')
parser.add_argument('--show-images', dest='show_images', action='store_true')
parser.set_defaults(show_loss=False, show_images=False)
Expand Down

0 comments on commit 72e6397

Please sign in to comment.