Skip to content

Commit

Permalink
Make combined_graph_data() more readable
Browse files Browse the repository at this point in the history
Add a few comments
Remove some indentation
Doesn't change functionality
  • Loading branch information
lukeyeager committed Jul 24, 2015
1 parent 3dae9f7 commit e487855
Showing 1 changed file with 35 additions and 27 deletions.
62 changes: 35 additions & 27 deletions digits/model/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,15 @@ def combined_graph_data(self, cull=True):
'names': {},
}

added_train_data = False
added_val_data = False

if self.train_outputs and 'epoch' in self.train_outputs:
added_column = False
if cull:
# max 200 data points
stride = max(len(self.train_outputs['epoch'].data)/100,1)
else:
# return all data
stride = 1
for name, output in self.train_outputs.iteritems():
if name not in ['epoch', 'learning_rate']:
Expand All @@ -474,31 +478,35 @@ def combined_graph_data(self, cull=True):
data['axes'][col_id] = 'y2'
else:
data['columns'].append([col_id] + output.data[::stride])
added_column = True
if added_column:
data['columns'].append(['train_epochs'] + self.train_outputs['epoch'].data[::stride])

if self.val_outputs and 'epoch' in self.val_outputs:
added_column = False
if cull:
stride = max(len(self.val_outputs['epoch'].data)/100,1)
added_train_data = True
if added_train_data:
data['columns'].append(['train_epochs'] + self.train_outputs['epoch'].data[::stride])

if self.val_outputs and 'epoch' in self.val_outputs:
if cull:
# max 200 data points
stride = max(len(self.val_outputs['epoch'].data)/100,1)
else:
# return all data
stride = 1
for name, output in self.val_outputs.iteritems():
if name not in ['epoch']:
col_id = '%s-val' % name
data['xs'][col_id] = 'val_epochs'
data['names'][col_id] = '%s (val)' % name
if 'accuracy' in output.kind.lower():
data['columns'].append([col_id] + [100*x for x in output.data[::stride]])
data['axes'][col_id] = 'y2'
else:
stride = 1
for name, output in self.val_outputs.iteritems():
if name not in ['epoch']:
col_id = '%s-val' % name
data['xs'][col_id] = 'val_epochs'
data['names'][col_id] = '%s (val)' % name
if 'accuracy' in output.kind.lower():
data['columns'].append([col_id] + [100*x for x in output.data[::stride]])
data['axes'][col_id] = 'y2'
else:
data['columns'].append([col_id] + output.data[::stride])
added_column = True
if added_column:
data['columns'].append(['val_epochs'] + self.val_outputs['epoch'].data[::stride])
# return data if we have both training and validation data
return data
# return None if we are missing either training or validation data
return None
data['columns'].append([col_id] + output.data[::stride])
added_val_data = True
if added_val_data:
data['columns'].append(['val_epochs'] + self.val_outputs['epoch'].data[::stride])

if added_train_data:
return data
else:
# return None if only validation data exists
# helps with ordering of columns in graph
return None

0 comments on commit e487855

Please sign in to comment.