diff --git a/demo/image_classification/.gitignore b/demo/image_classification/.gitignore index 76961dd1436f8..6a05b8f6632db 100644 --- a/demo/image_classification/.gitignore +++ b/demo/image_classification/.gitignore @@ -5,3 +5,5 @@ plot.png train.log image_provider_copy_1.py *pyc +train.list +test.list diff --git a/demo/image_classification/data/download_cifar.sh b/demo/image_classification/data/download_cifar.sh old mode 100644 new mode 100755 diff --git a/demo/image_classification/image_provider.py b/demo/image_classification/image_provider.py index 9e2f8b8949b39..305efbcdc6bb1 100644 --- a/demo/image_classification/image_provider.py +++ b/demo/image_classification/image_provider.py @@ -58,24 +58,29 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg, settings.logger.info('DataProvider Initialization finished') -@provider(init_hook=hook) -def processData(settings, file_name): +@provider(init_hook=hook, min_pool_size=0) +def processData(settings, file_list): """ The main function for loading data. Load the batch, iterate all the images and labels in this batch. - file_name: the batch file name. + file_list: the batch file list. """ - data = cPickle.load(io.open(file_name, 'rb')) - indexes = list(range(len(data['images']))) - if settings.is_train: - random.shuffle(indexes) - for i in indexes: - if settings.use_jpeg == 1: - img = image_util.decode_jpeg(data['images'][i]) - else: - img = data['images'][i] - img_feat = image_util.preprocess_img(img, settings.img_mean, - settings.img_size, settings.is_train, - settings.color) - label = data['labels'][i] - yield img_feat.tolist(), int(label) + with open(file_list, 'r') as fdata: + lines = [line.strip() for line in fdata] + random.shuffle(lines) + for file_name in lines: + with io.open(file_name.strip(), 'rb') as file: + data = cPickle.load(file) + indexes = list(range(len(data['images']))) + if settings.is_train: + random.shuffle(indexes) + for i in indexes: + if settings.use_jpeg == 1: + img = image_util.decode_jpeg(data['images'][i]) + else: + img = data['images'][i] + img_feat = image_util.preprocess_img(img, settings.img_mean, + settings.img_size, settings.is_train, + settings.color) + label = data['labels'][i] + yield img_feat.astype('float32'), int(label) diff --git a/demo/image_classification/preprocess.py b/demo/image_classification/preprocess.py index 0286a5d7e9dc8..fe7ea19bf0277 100755 --- a/demo/image_classification/preprocess.py +++ b/demo/image_classification/preprocess.py @@ -35,6 +35,8 @@ def option_parser(): data_creator = ImageClassificationDatasetCreater(data_dir, processed_image_size, color) + data_creator.train_list_name = "train.txt" + data_creator.test_list_name = "test.txt" data_creator.num_per_batch = 1000 data_creator.overwrite = True data_creator.create_batches() diff --git a/demo/image_classification/preprocess.sh b/demo/image_classification/preprocess.sh index dfe3eb95d1ab8..e3e86ff10675c 100755 --- a/demo/image_classification/preprocess.sh +++ b/demo/image_classification/preprocess.sh @@ -17,3 +17,6 @@ set -e data_dir=./data/cifar-out python preprocess.py -i $data_dir -s 32 -c 1 + +echo "data/cifar-out/batches/train.txt" > train.list +echo "data/cifar-out/batches/test.txt" > test.list diff --git a/demo/image_classification/vgg_16_cifar.py b/demo/image_classification/vgg_16_cifar.py index e8b8af4bd313d..edd6988c48acd 100755 --- a/demo/image_classification/vgg_16_cifar.py +++ b/demo/image_classification/vgg_16_cifar.py @@ -25,8 +25,8 @@ 'img_size': 32,'num_classes': 10, 'use_jpeg': 1,'color': "color"} - define_py_data_sources2(train_list=data_dir+"train.list", - test_list=data_dir+'test.list', + define_py_data_sources2(train_list="train.list", + test_list="train.list", module='image_provider', obj='processData', args=args)