From 6d7e97b56d7631be8864af36ce43c82d24e84b1a Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Mon, 10 Sep 2018 22:15:00 -0700 Subject: [PATCH] Change the way NDArrayIter handle the last batch (#12285) * 1. move the shuffle to the reset 2. modify the roll_over behavior accordingly * refactor the concat part * refactor the code * implement unit test for last_batch_handle * refactor the getdata part * add docstring and refine the code according to linter * 1. add test case for NDArrayIter_h5py 2. refactor the implementation * update contributions doc * fix wording * update doc for roll_over * 1. add test for second iteration of roll_over 2. add shuffle test case * fix some wording and refine the variables naming * move utility function to new file * move utility function to io_utils.py * change shuffle function name to avoid redefining name * make io as a module * rename the utility functions * disable wildcard-import --- CONTRIBUTORS.md | 1 + python/mxnet/io/__init__.py | 29 ++++ python/mxnet/{ => io}/io.py | 280 ++++++++++++++++--------------- python/mxnet/io/utils.py | 86 ++++++++++ tests/python/unittest/test_io.py | 122 ++++++++------ 5 files changed, 328 insertions(+), 190 deletions(-) create mode 100644 python/mxnet/io/__init__.py rename python/mxnet/{ => io}/io.py (82%) create mode 100644 python/mxnet/io/utils.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8d8aeaca73e4..1c005d57c4a6 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -178,3 +178,4 @@ List of Contributors * [Aaron Markham](https://github.com/aaronmarkham) * [Sam Skalicky](https://github.com/samskalicky) * [Per Goncalves da Silva](https://github.com/perdasilva) +* [Cheng-Che Lee](https://github.com/stu1130) diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py new file mode 100644 index 000000000000..5c5e2e68d84a --- /dev/null +++ b/python/mxnet/io/__init__.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import +""" Data iterators for common data formats and utility functions.""" +from __future__ import absolute_import + +from . import io +from .io import * + +from . import utils +from .utils import * diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py similarity index 82% rename from python/mxnet/io.py rename to python/mxnet/io/io.py index 884e9294741a..2ae3e70045fb 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io/io.py @@ -17,30 +17,26 @@ """Data iterators for common data formats.""" from __future__ import absolute_import -from collections import OrderedDict, namedtuple +from collections import namedtuple import sys import ctypes import logging import threading -try: - import h5py -except ImportError: - h5py = None import numpy as np -from .base import _LIB -from .base import c_str_array, mx_uint, py_str -from .base import DataIterHandle, NDArrayHandle -from .base import mx_real_t -from .base import check_call, build_param_doc as _build_param_doc -from .ndarray import NDArray -from .ndarray.sparse import CSRNDArray -from .ndarray.sparse import array as sparse_array -from .ndarray import _ndarray_cls -from .ndarray import array -from .ndarray import concatenate -from .ndarray import arange -from .ndarray.random import shuffle as random_shuffle + +from ..base import _LIB +from ..base import c_str_array, mx_uint, py_str +from ..base import DataIterHandle, NDArrayHandle +from ..base import mx_real_t +from ..base import check_call, build_param_doc as _build_param_doc +from ..ndarray import NDArray +from ..ndarray.sparse import CSRNDArray +from ..ndarray import _ndarray_cls +from ..ndarray import array +from ..ndarray import concat + +from .utils import init_data, has_instance, getdata_by_idx class DataDesc(namedtuple('DataDesc', ['name', 'shape'])): """DataDesc is used to store name, shape, type and layout @@ -489,59 +485,6 @@ def getindex(self): def getpad(self): return self.current_batch.pad -def _init_data(data, allow_empty, default_name): - """Convert data into canonical form.""" - assert (data is not None) or allow_empty - if data is None: - data = [] - - if isinstance(data, (np.ndarray, NDArray, h5py.Dataset) - if h5py else (np.ndarray, NDArray)): - data = [data] - if isinstance(data, list): - if not allow_empty: - assert(len(data) > 0) - if len(data) == 1: - data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type - else: - data = OrderedDict( # pylint: disable=redefined-variable-type - [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)]) - if not isinstance(data, dict): - raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \ - "a list of them or dict with them as values") - for k, v in data.items(): - if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray): - try: - data[k] = array(v) - except: - raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \ - "should be NDArray, numpy.ndarray or h5py.Dataset") - - return list(sorted(data.items())) - -def _has_instance(data, dtype): - """Return True if ``data`` has instance of ``dtype``. - This function is called after _init_data. - ``data`` is a list of (str, NDArray)""" - for item in data: - _, arr = item - if isinstance(arr, dtype): - return True - return False - -def _shuffle(data, idx): - """Shuffle the data.""" - shuffle_data = [] - - for k, v in data: - if (isinstance(v, h5py.Dataset) if h5py else False): - shuffle_data.append((k, v)) - elif isinstance(v, CSRNDArray): - shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context))) - else: - shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) - - return shuffle_data class NDArrayIter(DataIter): """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset`` @@ -601,6 +544,22 @@ class NDArrayIter(DataIter): ... >>> batchidx # Remaining examples are discarded. So, 10/3 batches are created. 3 + >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over') + >>> batchidx = 0 + >>> for batch in dataiter: + ... batchidx += 1 + ... + >>> batchidx # Remaining examples are rolled over to the next iteration. + 3 + >>> dataiter.reset() + >>> dataiter.next().data[0].asnumpy() + [[[ 36. 37.] + [ 38. 39.]] + [[ 0. 1.] + [ 2. 3.]] + [[ 4. 5.] + [ 6. 7.]]] + (3L, 2L, 2L) `NDArrayIter` also supports multiple input and labels. @@ -633,8 +592,11 @@ class NDArrayIter(DataIter): Only supported if no h5py.Dataset inputs are used. last_batch_handle : str, optional How to handle the last batch. This parameter can be 'pad', 'discard' or - 'roll_over'. 'roll_over' is intended for training and can cause problems - if used for prediction. + 'roll_over'. + If 'pad', the last batch will be padded with data starting from the begining + If 'discard', the last batch will be discarded + If 'roll_over', the remaining elements will be rolled over to the next iteration and + note that it is intended for training and can cause problems if used for prediction. data_name : str, optional The data name. label_name : str, optional @@ -645,36 +607,28 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False, label_name='softmax_label'): super(NDArrayIter, self).__init__(batch_size) - self.data = _init_data(data, allow_empty=False, default_name=data_name) - self.label = _init_data(label, allow_empty=True, default_name=label_name) + self.data = init_data(data, allow_empty=False, default_name=data_name) + self.label = init_data(label, allow_empty=True, default_name=label_name) - if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and + if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and (last_batch_handle != 'discard')): raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \ " with `last_batch_handle` set to `discard`.") - # shuffle data - if shuffle: - tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32) - self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy() - self.data = _shuffle(self.data, self.idx) - self.label = _shuffle(self.label, self.idx) - else: - self.idx = np.arange(self.data[0][1].shape[0]) - - # batching - if last_batch_handle == 'discard': - new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size - self.idx = self.idx[:new_n] + self.idx = np.arange(self.data[0][1].shape[0]) + self.shuffle = shuffle + self.last_batch_handle = last_batch_handle + self.batch_size = batch_size + self.cursor = -self.batch_size + self.num_data = self.idx.shape[0] + # shuffle + self.reset() self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label] self.num_source = len(self.data_list) - self.num_data = self.idx.shape[0] - assert self.num_data >= batch_size, \ - "batch_size needs to be smaller than data size." - self.cursor = -batch_size - self.batch_size = batch_size - self.last_batch_handle = last_batch_handle + # used for 'roll_over' + self._cache_data = None + self._cache_label = None @property def provide_data(self): @@ -694,74 +648,126 @@ def provide_label(self): def hard_reset(self): """Ignore roll over data and set to start.""" + if self.shuffle: + self._shuffle_data() self.cursor = -self.batch_size + self._cache_data = None + self._cache_label = None def reset(self): - if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data: - self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size + """Resets the iterator to the beginning of the data.""" + if self.shuffle: + self._shuffle_data() + # the range below indicate the last batch + if self.last_batch_handle == 'roll_over' and \ + self.num_data - self.batch_size < self.cursor < self.num_data: + # (self.cursor - self.num_data) represents the data we have for the last batch + self.cursor = self.cursor - self.num_data - self.batch_size else: self.cursor = -self.batch_size def iter_next(self): + """Increments the coursor by batch_size for next batch + and check current cursor if it exceed the number of data points.""" self.cursor += self.batch_size return self.cursor < self.num_data def next(self): - if self.iter_next(): - return DataBatch(data=self.getdata(), label=self.getlabel(), \ - pad=self.getpad(), index=None) - else: + """Returns the next batch of data.""" + if not self.iter_next(): + raise StopIteration + data = self.getdata() + label = self.getlabel() + # iter should stop when last batch is not complete + if data[0].shape[0] != self.batch_size: + # in this case, cache it for next epoch + self._cache_data = data + self._cache_label = label raise StopIteration + return DataBatch(data=data, label=label, \ + pad=self.getpad(), index=None) + + def _getdata(self, data_source, start=None, end=None): + """Load data from underlying arrays.""" + assert start is not None or end is not None, 'should at least specify start or end' + start = start if start is not None else 0 + end = end if end is not None else data_source[0][1].shape[0] + s = slice(start, end) + return [ + x[1][s] + if isinstance(x[1], (np.ndarray, NDArray)) else + # h5py (only supports indices in increasing order) + array(x[1][sorted(self.idx[s])][[ + list(self.idx[s]).index(i) + for i in sorted(self.idx[s]) + ]]) for x in data_source + ] - def _getdata(self, data_source): + def _concat(self, first_data, second_data): + """Helper function to concat two NDArrays.""" + return [ + concat(first_data[0], second_data[0], dim=0) + ] + + def _batchify(self, data_source): """Load data from underlying arrays, internal use only.""" - assert(self.cursor < self.num_data), "DataIter needs reset." - if self.cursor + self.batch_size <= self.num_data: - return [ - # np.ndarray or NDArray case - x[1][self.cursor:self.cursor + self.batch_size] - if isinstance(x[1], (np.ndarray, NDArray)) else - # h5py (only supports indices in increasing order) - array(x[1][sorted(self.idx[ - self.cursor:self.cursor + self.batch_size])][[ - list(self.idx[self.cursor: - self.cursor + self.batch_size]).index(i) - for i in sorted(self.idx[ - self.cursor:self.cursor + self.batch_size]) - ]]) for x in data_source - ] - else: + assert self.cursor < self.num_data, 'DataIter needs reset.' + # first batch of next epoch with 'roll_over' + if self.last_batch_handle == 'roll_over' and \ + -self.batch_size < self.cursor < 0: + assert self._cache_data is not None or self._cache_label is not None, \ + 'next epoch should have cached data' + cache_data = self._cache_data if self._cache_data is not None else self._cache_label + second_data = self._getdata( + data_source, end=self.cursor + self.batch_size) + if self._cache_data is not None: + self._cache_data = None + else: + self._cache_label = None + return self._concat(cache_data, second_data) + # last batch with 'pad' + elif self.last_batch_handle == 'pad' and \ + self.cursor + self.batch_size > self.num_data: pad = self.batch_size - self.num_data + self.cursor - return [ - # np.ndarray or NDArray case - concatenate([x[1][self.cursor:], x[1][:pad]]) - if isinstance(x[1], (np.ndarray, NDArray)) else - # h5py (only supports indices in increasing order) - concatenate([ - array(x[1][sorted(self.idx[self.cursor:])][[ - list(self.idx[self.cursor:]).index(i) - for i in sorted(self.idx[self.cursor:]) - ]]), - array(x[1][sorted(self.idx[:pad])][[ - list(self.idx[:pad]).index(i) - for i in sorted(self.idx[:pad]) - ]]) - ]) for x in data_source - ] + first_data = self._getdata(data_source, start=self.cursor) + second_data = self._getdata(data_source, end=pad) + return self._concat(first_data, second_data) + # normal case + else: + if self.cursor + self.batch_size < self.num_data: + end_idx = self.cursor + self.batch_size + # get incomplete last batch + else: + end_idx = self.num_data + return self._getdata(data_source, self.cursor, end_idx) def getdata(self): - return self._getdata(self.data) + """Get data.""" + return self._batchify(self.data) def getlabel(self): - return self._getdata(self.label) + """Get label.""" + return self._batchify(self.label) def getpad(self): + """Get pad value of DataBatch.""" if self.last_batch_handle == 'pad' and \ self.cursor + self.batch_size > self.num_data: return self.cursor + self.batch_size - self.num_data + # check the first batch + elif self.last_batch_handle == 'roll_over' and \ + -self.batch_size < self.cursor < 0: + return -self.cursor else: return 0 + def _shuffle_data(self): + """Shuffle the data.""" + # shuffle index + np.random.shuffle(self.idx) + # get the data by corresponding index + self.data = getdata_by_idx(self.data, self.idx) + self.label = getdata_by_idx(self.label, self.idx) class MXDataIter(DataIter): """A python wrapper a C++ data iterator. @@ -773,7 +779,7 @@ class MXDataIter(DataIter): underlying C++ data iterators. Usually you don't need to interact with `MXDataIter` directly unless you are - implementing your own data iterators in C++. To do that, please refer to + implementing your own data iterators in C+ +. To do that, please refer to examples under the `src/io` folder. Parameters diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py new file mode 100644 index 000000000000..872e6410d7de --- /dev/null +++ b/python/mxnet/io/utils.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""utility functions for io.py""" +from collections import OrderedDict + +import numpy as np +try: + import h5py +except ImportError: + h5py = None + +from ..ndarray.sparse import CSRNDArray +from ..ndarray.sparse import array as sparse_array +from ..ndarray import NDArray +from ..ndarray import array + +def init_data(data, allow_empty, default_name): + """Convert data into canonical form.""" + assert (data is not None) or allow_empty + if data is None: + data = [] + + if isinstance(data, (np.ndarray, NDArray, h5py.Dataset) + if h5py else (np.ndarray, NDArray)): + data = [data] + if isinstance(data, list): + if not allow_empty: + assert(len(data) > 0) + if len(data) == 1: + data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type + else: + data = OrderedDict( # pylint: disable=redefined-variable-type + [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)]) + if not isinstance(data, dict): + raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + + "a list of them or dict with them as values") + for k, v in data.items(): + if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray): + try: + data[k] = array(v) + except: + raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + + "should be NDArray, numpy.ndarray or h5py.Dataset") + + return list(sorted(data.items())) + + +def has_instance(data, dtype): + """Return True if ``data`` has instance of ``dtype``. + This function is called after _init_data. + ``data`` is a list of (str, NDArray)""" + for item in data: + _, arr = item + if isinstance(arr, dtype): + return True + return False + + +def getdata_by_idx(data, idx): + """Shuffle the data.""" + shuffle_data = [] + + for k, v in data: + if (isinstance(v, h5py.Dataset) if h5py else False): + shuffle_data.append((k, v)) + elif isinstance(v, CSRNDArray): + shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context))) + else: + shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) + + return shuffle_data diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 4dfa69cc1050..ae686261b818 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -88,80 +88,88 @@ def test_Cifar10Rec(): assert(labelcount[i] == 5000) -def test_NDArrayIter(): +def _init_NDArrayIter_data(): data = np.ones([1000, 2, 2]) - label = np.ones([1000, 1]) + labels = np.ones([1000, 1]) for i in range(1000): data[i] = i / 100 - label[i] = i / 100 - dataiter = mx.io.NDArrayIter( - data, label, 128, True, last_batch_handle='pad') - batchidx = 0 + labels[i] = i / 100 + return data, labels + + +def _test_last_batch_handle(data, labels): + # Test the three parameters 'pad', 'discard', 'roll_over' + last_batch_handle_list = ['pad', 'discard' , 'roll_over'] + labelcount_list = [(124, 100), (100, 96), (100, 96)] + batch_count_list = [8, 7, 7] + + for idx in range(len(last_batch_handle_list)): + dataiter = mx.io.NDArrayIter( + data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx]) + batch_count = 0 + labelcount = [0 for i in range(10)] + for batch in dataiter: + label = batch.label[0].asnumpy().flatten() + # check data if it matches corresponding labels + assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx] + for i in range(label.shape[0]): + labelcount[int(label[i])] += 1 + # keep the last batch of 'pad' to be used later + # to test first batch of roll_over in second iteration + batch_count += 1 + if last_batch_handle_list[idx] == 'pad' and \ + batch_count == 8: + cache = batch.data[0].asnumpy() + # check if batchifying functionality work properly + assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx] + assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx] + assert batch_count == batch_count_list[idx] + # roll_over option + dataiter.reset() + assert np.array_equal(dataiter.next().data[0].asnumpy(), cache) + + +def _test_shuffle(data, labels): + dataiter = mx.io.NDArrayIter(data, labels, 1, False) + batch_list = [] for batch in dataiter: - batchidx += 1 - assert(batchidx == 8) - dataiter = mx.io.NDArrayIter( - data, label, 128, False, last_batch_handle='pad') - batchidx = 0 - labelcount = [0 for i in range(10)] + # cache the original data + batch_list.append(batch.data[0].asnumpy()) + dataiter = mx.io.NDArrayIter(data, labels, 1, True) + idx_list = dataiter.idx + i = 0 for batch in dataiter: - label = batch.label[0].asnumpy().flatten() - assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()) - for i in range(label.shape[0]): - labelcount[int(label[i])] += 1 + # check if each data point have been shuffled to corresponding positions + assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]]) + i += 1 - for i in range(10): - if i == 0: - assert(labelcount[i] == 124) - else: - assert(labelcount[i] == 100) + +def test_NDArrayIter(): + data, labels = _init_NDArrayIter_data() + _test_last_batch_handle(data, labels) + _test_shuffle(data, labels) def test_NDArrayIter_h5py(): if not h5py: return - data = np.ones([1000, 2, 2]) - label = np.ones([1000, 1]) - for i in range(1000): - data[i] = i / 100 - label[i] = i / 100 + data, labels = _init_NDArrayIter_data() try: - os.remove("ndarraytest.h5") + os.remove('ndarraytest.h5') except OSError: pass - with h5py.File("ndarraytest.h5") as f: - f.create_dataset("data", data=data) - f.create_dataset("label", data=label) - - dataiter = mx.io.NDArrayIter( - f["data"], f["label"], 128, True, last_batch_handle='pad') - batchidx = 0 - for batch in dataiter: - batchidx += 1 - assert(batchidx == 8) - - dataiter = mx.io.NDArrayIter( - f["data"], f["label"], 128, False, last_batch_handle='pad') - labelcount = [0 for i in range(10)] - for batch in dataiter: - label = batch.label[0].asnumpy().flatten() - assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()) - for i in range(label.shape[0]): - labelcount[int(label[i])] += 1 + with h5py.File('ndarraytest.h5') as f: + f.create_dataset('data', data=data) + f.create_dataset('label', data=labels) + _test_last_batch_handle(f['data'], f['label']) try: os.remove("ndarraytest.h5") except OSError: pass - for i in range(10): - if i == 0: - assert(labelcount[i] == 124) - else: - assert(labelcount[i] == 100) - def test_NDArrayIter_csr(): # creating toy data @@ -182,12 +190,20 @@ def test_NDArrayIter_csr(): {'data': train_data}, dns, batch_size) except ImportError: pass + # scipy.sparse.csr_matrix with shuffle + num_batch = 0 + csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size, + shuffle=True, last_batch_handle='discard')) + for _ in csr_iter: + num_batch += 1 + + assert(num_batch == num_rows // batch_size) # CSRNDArray with shuffle csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size, shuffle=True, last_batch_handle='discard')) num_batch = 0 - for batch in csr_iter: + for _ in csr_iter: num_batch += 1 assert(num_batch == num_rows // batch_size)