Skip to content

Commit

Permalink
add batch_axis in validation handler (apache#17134)
Browse files Browse the repository at this point in the history
* add batch_axis in validation handler

* Add test case for batch axis support

* change test class name
  • Loading branch information
liuzh47 authored and Zheng committed Jan 21, 2020
1 parent 77adc16 commit 10c01de
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def batch_end(self, estimator, *args, **kwargs):
def epoch_end(self, estimator, *args, **kwargs):
self.current_epoch += 1
if self.epoch_period and self.current_epoch % self.epoch_period == 0:
self.eval_fn(val_data=self.val_data)
self.eval_fn(val_data=self.val_data, batch_axis=estimator.batch_axis)


class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
Expand Down Expand Up @@ -734,6 +734,6 @@ def batch_end(self, estimator, *args, **kwargs):
loss = [loss]
if isinstance(loss, list):
for l in loss:
batch_size += l.shape[estimator.batch_axis]
batch_size += l.shape[0]

estimator.trainer.step(batch_size)
38 changes: 38 additions & 0 deletions tests/python/unittest/test_gluon_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,28 @@
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler
from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler
from mxnet.gluon.data.dataset import Dataset
try:
from StringIO import StringIO
except ImportError:
from io import StringIO

class TestAxisArrayDataset(Dataset):
def __init__(self, * args):
self._length = len(args[1])
self._data = []
for _, data in enumerate(args):
self._data.append(data)

def __getitem__(self, idx):
if len(self._data) == 1:
return self._data[idx][0]
else:
return tuple(data[:, idx] for data in self._data)

def __len__(self):
return self._length

def _get_test_network(net=nn.Sequential()):
net.add(nn.Dense(128, activation='relu', flatten=False),
nn.Dense(64, activation='relu'),
Expand All @@ -44,6 +61,11 @@ def _get_test_data(in_size=32):
data_arr = mx.gluon.data.dataset.ArrayDataset(data, label)
return mx.gluon.data.DataLoader(data_arr, batch_size=8)

def _get_batch_axis_test_data(in_size=32):
data = nd.ones((100, in_size))
label = nd.zeros((1, in_size))
data_arr = TestAxisArrayDataset(data, label)
return mx.gluon.data.DataLoader(data_arr, batch_size=8)

def test_checkpoint_handler():
with TemporaryDirectory() as tmpdir:
Expand Down Expand Up @@ -263,3 +285,19 @@ def test_logging_interval():

assert(info_len == int(data_size/batch_size/log_interval) + 1)

def test_validation_handler_batch_axis():
# test case #1: test batch_axis=0
test_data = _get_test_data()
net = _get_test_network()
ce_loss = loss.SoftmaxCrossEntropyLoss()
acc = mx.metric.Accuracy()
est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
est.fit(test_data, epochs=3)

#test case #2: test batch_axis=1
test_data = _get_batch_axis_test_data()
val_data = _get_batch_axis_test_data(in_size=30)
est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
est.fit(test_data, val_data=val_data,
epochs=3, batch_axis=1)

0 comments on commit 10c01de

Please sign in to comment.