Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Revert "Change the way NDArrayIter handle the last batch" #12537

Merged
merged 9 commits into from
Sep 12, 2018
1 change: 0 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,3 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Cheng-Che Lee](https://github.com/stu1130)
280 changes: 137 additions & 143 deletions python/mxnet/io/io.py → python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,30 @@

"""Data iterators for common data formats."""
from __future__ import absolute_import
from collections import namedtuple
from collections import OrderedDict, namedtuple

import sys
import ctypes
import logging
import threading
try:
import h5py
except ImportError:
h5py = None
import numpy as np

from ..base import _LIB
from ..base import c_str_array, mx_uint, py_str
from ..base import DataIterHandle, NDArrayHandle
from ..base import mx_real_t
from ..base import check_call, build_param_doc as _build_param_doc
from ..ndarray import NDArray
from ..ndarray.sparse import CSRNDArray
from ..ndarray import _ndarray_cls
from ..ndarray import array
from ..ndarray import concat

from .utils import init_data, has_instance, getdata_by_idx
from .base import _LIB
from .base import c_str_array, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .ndarray.sparse import CSRNDArray
from .ndarray.sparse import array as sparse_array
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate
from .ndarray import arange
from .ndarray.random import shuffle as random_shuffle

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
Expand Down Expand Up @@ -485,6 +489,59 @@ def getindex(self):
def getpad(self):
return self.current_batch.pad

def _init_data(data, allow_empty, default_name):
"""Convert data into canonical form."""
assert (data is not None) or allow_empty
if data is None:
data = []

if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
if h5py else (np.ndarray, NDArray)):
data = [data]
if isinstance(data, list):
if not allow_empty:
assert(len(data) > 0)
if len(data) == 1:
data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
else:
data = OrderedDict( # pylint: disable=redefined-variable-type
[('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
if not isinstance(data, dict):
raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
"a list of them or dict with them as values")
for k, v in data.items():
if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
try:
data[k] = array(v)
except:
raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
"should be NDArray, numpy.ndarray or h5py.Dataset")

return list(sorted(data.items()))

def _has_instance(data, dtype):
"""Return True if ``data`` has instance of ``dtype``.
This function is called after _init_data.
``data`` is a list of (str, NDArray)"""
for item in data:
_, arr = item
if isinstance(arr, dtype):
return True
return False

def _shuffle(data, idx):
"""Shuffle the data."""
shuffle_data = []

for k, v in data:
if (isinstance(v, h5py.Dataset) if h5py else False):
shuffle_data.append((k, v))
elif isinstance(v, CSRNDArray):
shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
else:
shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

return shuffle_data

class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
Expand Down Expand Up @@ -544,22 +601,6 @@ class NDArrayIter(DataIter):
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
3
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
>>> batchidx = 0
>>> for batch in dataiter:
... batchidx += 1
...
>>> batchidx # Remaining examples are rolled over to the next iteration.
3
>>> dataiter.reset()
>>> dataiter.next().data[0].asnumpy()
[[[ 36. 37.]
[ 38. 39.]]
[[ 0. 1.]
[ 2. 3.]]
[[ 4. 5.]
[ 6. 7.]]]
(3L, 2L, 2L)

`NDArrayIter` also supports multiple input and labels.

Expand Down Expand Up @@ -592,11 +633,8 @@ class NDArrayIter(DataIter):
Only supported if no h5py.Dataset inputs are used.
last_batch_handle : str, optional
How to handle the last batch. This parameter can be 'pad', 'discard' or
'roll_over'.
If 'pad', the last batch will be padded with data starting from the begining
If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration and
note that it is intended for training and can cause problems if used for prediction.
'roll_over'. 'roll_over' is intended for training and can cause problems
if used for prediction.
data_name : str, optional
The data name.
label_name : str, optional
Expand All @@ -607,28 +645,36 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
label_name='softmax_label'):
super(NDArrayIter, self).__init__(batch_size)

self.data = init_data(data, allow_empty=False, default_name=data_name)
self.label = init_data(label, allow_empty=True, default_name=label_name)
self.data = _init_data(data, allow_empty=False, default_name=data_name)
self.label = _init_data(label, allow_empty=True, default_name=label_name)

if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and
if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
(last_batch_handle != 'discard')):
raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
" with `last_batch_handle` set to `discard`.")

self.idx = np.arange(self.data[0][1].shape[0])
self.shuffle = shuffle
self.last_batch_handle = last_batch_handle
self.batch_size = batch_size
self.cursor = -self.batch_size
self.num_data = self.idx.shape[0]
# shuffle
self.reset()
# shuffle data
if shuffle:
tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
self.data = _shuffle(self.data, self.idx)
self.label = _shuffle(self.label, self.idx)
else:
self.idx = np.arange(self.data[0][1].shape[0])

# batching
if last_batch_handle == 'discard':
new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
self.idx = self.idx[:new_n]

self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
self.num_source = len(self.data_list)
# used for 'roll_over'
self._cache_data = None
self._cache_label = None
self.num_data = self.idx.shape[0]
assert self.num_data >= batch_size, \
"batch_size needs to be smaller than data size."
self.cursor = -batch_size
self.batch_size = batch_size
self.last_batch_handle = last_batch_handle

@property
def provide_data(self):
Expand All @@ -648,126 +694,74 @@ def provide_label(self):

def hard_reset(self):
"""Ignore roll over data and set to start."""
if self.shuffle:
self._shuffle_data()
self.cursor = -self.batch_size
self._cache_data = None
self._cache_label = None

def reset(self):
"""Resets the iterator to the beginning of the data."""
if self.shuffle:
self._shuffle_data()
# the range below indicate the last batch
if self.last_batch_handle == 'roll_over' and \
self.num_data - self.batch_size < self.cursor < self.num_data:
# (self.cursor - self.num_data) represents the data we have for the last batch
self.cursor = self.cursor - self.num_data - self.batch_size
if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
else:
self.cursor = -self.batch_size

def iter_next(self):
"""Increments the coursor by batch_size for next batch
and check current cursor if it exceed the number of data points."""
self.cursor += self.batch_size
return self.cursor < self.num_data

def next(self):
"""Returns the next batch of data."""
if not self.iter_next():
raise StopIteration
data = self.getdata()
label = self.getlabel()
# iter should stop when last batch is not complete
if data[0].shape[0] != self.batch_size:
# in this case, cache it for next epoch
self._cache_data = data
self._cache_label = label
if self.iter_next():
return DataBatch(data=self.getdata(), label=self.getlabel(), \
pad=self.getpad(), index=None)
else:
raise StopIteration
return DataBatch(data=data, label=label, \
pad=self.getpad(), index=None)

def _getdata(self, data_source, start=None, end=None):
"""Load data from underlying arrays."""
assert start is not None or end is not None, 'should at least specify start or end'
start = start if start is not None else 0
end = end if end is not None else data_source[0][1].shape[0]
s = slice(start, end)
return [
x[1][s]
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[s])][[
list(self.idx[s]).index(i)
for i in sorted(self.idx[s])
]]) for x in data_source
]

def _concat(self, first_data, second_data):
"""Helper function to concat two NDArrays."""
return [
concat(first_data[0], second_data[0], dim=0)
]

def _batchify(self, data_source):
def _getdata(self, data_source):
"""Load data from underlying arrays, internal use only."""
assert self.cursor < self.num_data, 'DataIter needs reset.'
# first batch of next epoch with 'roll_over'
if self.last_batch_handle == 'roll_over' and \
-self.batch_size < self.cursor < 0:
assert self._cache_data is not None or self._cache_label is not None, \
'next epoch should have cached data'
cache_data = self._cache_data if self._cache_data is not None else self._cache_label
second_data = self._getdata(
data_source, end=self.cursor + self.batch_size)
if self._cache_data is not None:
self._cache_data = None
else:
self._cache_label = None
return self._concat(cache_data, second_data)
# last batch with 'pad'
elif self.last_batch_handle == 'pad' and \
self.cursor + self.batch_size > self.num_data:
pad = self.batch_size - self.num_data + self.cursor
first_data = self._getdata(data_source, start=self.cursor)
second_data = self._getdata(data_source, end=pad)
return self._concat(first_data, second_data)
# normal case
assert(self.cursor < self.num_data), "DataIter needs reset."
if self.cursor + self.batch_size <= self.num_data:
return [
# np.ndarray or NDArray case
x[1][self.cursor:self.cursor + self.batch_size]
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[
self.cursor:self.cursor + self.batch_size])][[
list(self.idx[self.cursor:
self.cursor + self.batch_size]).index(i)
for i in sorted(self.idx[
self.cursor:self.cursor + self.batch_size])
]]) for x in data_source
]
else:
if self.cursor + self.batch_size < self.num_data:
end_idx = self.cursor + self.batch_size
# get incomplete last batch
else:
end_idx = self.num_data
return self._getdata(data_source, self.cursor, end_idx)
pad = self.batch_size - self.num_data + self.cursor
return [
# np.ndarray or NDArray case
concatenate([x[1][self.cursor:], x[1][:pad]])
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
concatenate([
array(x[1][sorted(self.idx[self.cursor:])][[
list(self.idx[self.cursor:]).index(i)
for i in sorted(self.idx[self.cursor:])
]]),
array(x[1][sorted(self.idx[:pad])][[
list(self.idx[:pad]).index(i)
for i in sorted(self.idx[:pad])
]])
]) for x in data_source
]

def getdata(self):
"""Get data."""
return self._batchify(self.data)
return self._getdata(self.data)

def getlabel(self):
"""Get label."""
return self._batchify(self.label)
return self._getdata(self.label)

def getpad(self):
"""Get pad value of DataBatch."""
if self.last_batch_handle == 'pad' and \
self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
# check the first batch
elif self.last_batch_handle == 'roll_over' and \
-self.batch_size < self.cursor < 0:
return -self.cursor
else:
return 0

def _shuffle_data(self):
"""Shuffle the data."""
# shuffle index
np.random.shuffle(self.idx)
# get the data by corresponding index
self.data = getdata_by_idx(self.data, self.idx)
self.label = getdata_by_idx(self.label, self.idx)

class MXDataIter(DataIter):
"""A python wrapper a C++ data iterator.
Expand All @@ -779,7 +773,7 @@ class MXDataIter(DataIter):
underlying C++ data iterators.

Usually you don't need to interact with `MXDataIter` directly unless you are
implementing your own data iterators in C+ +. To do that, please refer to
implementing your own data iterators in C++. To do that, please refer to
examples under the `src/io` folder.

Parameters
Expand Down
Loading