Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed Nov 8, 2016
1 parent 2641348 commit 6d187f9
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
2 changes: 2 additions & 0 deletions demo/image_classification/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ plot.png
train.log
image_provider_copy_1.py
*pyc
train.list
test.list
33 changes: 18 additions & 15 deletions demo/image_classification/image_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,21 @@ def processData(settings, file_list):
file_list: the batch file list.
"""
with open(file_list, 'r') as fdata:
for file_name in fdata:
data = cPickle.load(io.open(file_name.strip(), 'rb'))
indexes = list(range(len(data['images'])))
if settings.is_train:
random.shuffle(indexes)
for i in indexes:
if settings.use_jpeg == 1:
img = image_util.decode_jpeg(data['images'][i])
else:
img = data['images'][i]
img_feat = image_util.preprocess_img(img, settings.img_mean,
settings.img_size, settings.is_train,
settings.color)
label = data['labels'][i]
yield img_feat.astype('float32'), int(label)
lines = [line.strip() for line in fdata]
random.shuffle(lines)
for file_name in lines:
with io.open(file_name.strip(), 'rb') as file:
data = cPickle.load(file)
indexes = list(range(len(data['images'])))
if settings.is_train:
random.shuffle(indexes)
for i in indexes:
if settings.use_jpeg == 1:
img = image_util.decode_jpeg(data['images'][i])
else:
img = data['images'][i]
img_feat = image_util.preprocess_img(img, settings.img_mean,
settings.img_size, settings.is_train,
settings.color)
label = data['labels'][i]
yield img_feat.astype('float32'), int(label)
2 changes: 2 additions & 0 deletions demo/image_classification/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def option_parser():
data_creator = ImageClassificationDatasetCreater(data_dir,
processed_image_size,
color)
data_creator.train_list_name = "train.txt"
data_creator.test_list_name = "test.txt"
data_creator.num_per_batch = 1000
data_creator.overwrite = True
data_creator.create_batches()
4 changes: 2 additions & 2 deletions demo/image_classification/preprocess.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ data_dir=./data/cifar-out

python preprocess.py -i $data_dir -s 32 -c 1

echo "data/cifar-out/batches/train.list" > trn.list
echo "data/cifar-out/batches/test.list" > tst.list
echo "data/cifar-out/batches/train.txt" > train.list
echo "data/cifar-out/batches/test.txt" > test.list
4 changes: 2 additions & 2 deletions demo/image_classification/vgg_16_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
'img_size': 32,'num_classes': 10,
'use_jpeg': 1,'color': "color"}

define_py_data_sources2(train_list="trn.list",
test_list="tst.list",
define_py_data_sources2(train_list="train.list",
test_list="train.list",
module='image_provider',
obj='processData',
args=args)
Expand Down

0 comments on commit 6d187f9

Please sign in to comment.