-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix xmap_readers and refine flowers dataset #2631
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,18 +13,18 @@ | |
# limitations under the License. | ||
""" | ||
This module will download dataset from | ||
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html | ||
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html | ||
and parse train/test set intopaddle reader creators. | ||
|
||
This set contains images of flowers belonging to 102 different categories. | ||
This set contains images of flowers belonging to 102 different categories. | ||
The images were acquired by searching the web and taking pictures. There are a | ||
minimum of 40 images for each category. | ||
|
||
The database was used in: | ||
|
||
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large | ||
number of classes.Proceedings of the Indian Conference on Computer Vision, | ||
Graphics and Image Processing (2008) | ||
number of classes.Proceedings of the Indian Conference on Computer Vision, | ||
Graphics and Image Processing (2008) | ||
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. | ||
|
||
""" | ||
|
@@ -34,9 +34,9 @@ | |
import tarfile | ||
import scipy.io as scio | ||
from paddle.v2.image import * | ||
from paddle.v2.reader import * | ||
import os | ||
import numpy as np | ||
import paddle.v2 as paddle | ||
from multiprocessing import cpu_count | ||
__all__ = ['train', 'test', 'valid'] | ||
|
||
|
@@ -53,8 +53,8 @@ def default_mapper(sample): | |
map image bytes data to type needed by model input layer | ||
''' | ||
img, label = sample | ||
img = paddle.image.load_image_bytes(img) | ||
img = paddle.image.simple_transform(img, 256, 224, True) | ||
img = load_image_bytes(img) | ||
img = simple_transform(img, 256, 224, True) | ||
return img.flatten().astype('float32'), label | ||
|
||
|
||
|
@@ -63,22 +63,23 @@ def reader_creator(data_file, | |
setid_file, | ||
dataset_name, | ||
mapper=default_mapper, | ||
buffered_size=1024): | ||
buffered_size=1024, | ||
useXmap=True): | ||
''' | ||
1. read images from tar file and | ||
1. read images from tar file and | ||
merge images into batch files in 102flowers.tgz_batch/ | ||
2. get a reader to read sample from batch file | ||
:param data_file: downloaded data file | ||
|
||
:param data_file: downloaded data file | ||
:type data_file: string | ||
:param label_file: downloaded label file | ||
:param label_file: downloaded label file | ||
:type label_file: string | ||
:param setid_file: downloaded setid file containing information | ||
about how to split dataset | ||
:type setid_file: string | ||
:param dataset_name: data set name (tstid|trnid|valid) | ||
:type dataset_name: string | ||
:param mapper: a function to map image bytes data to type | ||
:param mapper: a function to map image bytes data to type | ||
needed by model input layer | ||
:type mapper: callable | ||
:param buffered_size: the size of buffer used to process images | ||
|
@@ -105,15 +106,17 @@ def reader(): | |
for sample, label in itertools.izip(data, batch['label']): | ||
yield sample, int(label) | ||
|
||
return paddle.reader.xmap_readers(mapper, reader, | ||
cpu_count(), buffered_size) | ||
if useXmap: | ||
return xmap_readers(mapper, reader, cpu_count(), buffered_size) | ||
else: | ||
return map_readers(mapper, reader) | ||
|
||
|
||
def train(mapper=default_mapper, buffered_size=1024): | ||
def train(mapper=default_mapper, buffered_size=1024, useXmap=True): | ||
''' | ||
Create flowers training set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
Create flowers training set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
translated from original color image by steps: | ||
1. resize to 256*256 | ||
2. random crop to 224*224 | ||
|
@@ -128,15 +131,15 @@ def train(mapper=default_mapper, buffered_size=1024): | |
return reader_creator( | ||
download(DATA_URL, 'flowers', DATA_MD5), | ||
download(LABEL_URL, 'flowers', LABEL_MD5), | ||
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, | ||
buffered_size) | ||
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, | ||
buffered_size, useXmap) | ||
|
||
|
||
def test(mapper=default_mapper, buffered_size=1024): | ||
def test(mapper=default_mapper, buffered_size=1024, useXmap=True): | ||
''' | ||
Create flowers test set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
Create flowers test set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
translated from original color image by steps: | ||
1. resize to 256*256 | ||
2. random crop to 224*224 | ||
|
@@ -151,15 +154,15 @@ def test(mapper=default_mapper, buffered_size=1024): | |
return reader_creator( | ||
download(DATA_URL, 'flowers', DATA_MD5), | ||
download(LABEL_URL, 'flowers', LABEL_MD5), | ||
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, | ||
buffered_size) | ||
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to add some comments why change the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. |
||
buffered_size, useXmap) | ||
|
||
|
||
def valid(mapper=default_mapper, buffered_size=1024): | ||
def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): | ||
''' | ||
Create flowers validation set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
Create flowers validation set reader. | ||
It returns a reader, each sample in the reader is | ||
image pixels in [0, 1] and label in [1, 102] | ||
translated from original color image by steps: | ||
1. resize to 256*256 | ||
2. random crop to 224*224 | ||
|
@@ -175,7 +178,7 @@ def valid(mapper=default_mapper, buffered_size=1024): | |
download(DATA_URL, 'flowers', DATA_MD5), | ||
download(LABEL_URL, 'flowers', LABEL_MD5), | ||
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, | ||
buffered_size) | ||
buffered_size, useXmap) | ||
|
||
|
||
def fetch(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,12 +166,12 @@ def buffered(reader, size): | |
The buffered data reader will read and save data entries into a | ||
buffer. Reading from the buffered data reader will proceed as long | ||
as the buffer is not empty. | ||
|
||
:param reader: the data reader to read from. | ||
:type reader: callable | ||
:param size: max buffer size. | ||
:type size: int | ||
|
||
:returns: the buffered data reader. | ||
""" | ||
|
||
|
@@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): | |
:type mapper: callable | ||
:param reader: the data reader to read from | ||
:type reader: callable | ||
:param process_num: process number to handle original sample | ||
:param process_num: process number to handle original sample | ||
:type process_num: int | ||
:param buffer_size: max buffer size | ||
:type buffer_size: int | ||
|
@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): | |
:rtype: callable | ||
""" | ||
end = XmapEndSignal() | ||
in_queue = Queue(buffer_size) | ||
out_queue = Queue(buffer_size) | ||
out_order = [0] | ||
|
||
# define a worker to read samples from reader to in_queue | ||
def read_worker(reader, in_queue): | ||
|
@@ -266,12 +263,6 @@ def order_read_worker(reader, in_queue): | |
in_order += 1 | ||
in_queue.put(end) | ||
|
||
# start a read worker in a thread | ||
target = order_read_worker if order else read_worker | ||
t = Thread(target=target, args=(reader, in_queue)) | ||
t.daemon = True | ||
t.start() | ||
|
||
# define a worker to handle samples from in_queue by mapper | ||
# and put mapped samples into out_queue | ||
def handle_worker(in_queue, out_queue, mapper): | ||
|
@@ -298,19 +289,27 @@ def order_handle_worker(in_queue, out_queue, mapper, out_order): | |
in_queue.put(end) | ||
out_queue.put(end) | ||
|
||
# start several handle_workers | ||
target = order_handle_worker if order else handle_worker | ||
args = (in_queue, out_queue, mapper, out_order) if order else ( | ||
in_queue, out_queue, mapper) | ||
workers = [] | ||
for i in xrange(process_num): | ||
worker = Thread(target=target, args=args) | ||
worker.daemon = True | ||
workers.append(worker) | ||
for w in workers: | ||
w.start() | ||
|
||
def xreader(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm unfamiliar with this part, should we put xreader in another file? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get it. LGTM. |
||
in_queue = Queue(buffer_size) | ||
out_queue = Queue(buffer_size) | ||
out_order = [0] | ||
# start a read worker in a thread | ||
target = order_read_worker if order else read_worker | ||
t = Thread(target=target, args=(reader, in_queue)) | ||
t.daemon = True | ||
t.start() | ||
# start several handle_workers | ||
target = order_handle_worker if order else handle_worker | ||
args = (in_queue, out_queue, mapper, out_order) if order else ( | ||
in_queue, out_queue, mapper) | ||
workers = [] | ||
for i in xrange(process_num): | ||
worker = Thread(target=target, args=args) | ||
worker.daemon = True | ||
workers.append(worker) | ||
for w in workers: | ||
w.start() | ||
|
||
sample = out_queue.get() | ||
while not isinstance(sample, XmapEndSignal): | ||
yield sample | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useXmap -> use_xmap
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx.