-
Notifications
You must be signed in to change notification settings - Fork 5
TfRecord的一些坑
DingfengShi edited this page Feb 26, 2018
·
6 revisions
- Int64List
- FloatList
- BytesList
这处理的是标量,也就是一维列表(所以要注意Feature构造函数里形参value一定要赋值成一维列表list)。放入的数据类型是ndarray。 由于序列化以后只能存储一维数据,所以如果要放入张量,必须把张量flatten以后存入(图片tobytes以后也是一维的)。如果形状是不固定的,这时候就可以用另外一个Int64List去存储一个"shape"的list,用于复原数据形状。
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd + name + "/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((224, 224))
img_raw = img.tobytes() #将图片转化为原生bytes
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [224, 224, 3])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(features['label'], tf.int32)
- PIL库中Image类的size和tensorflow里的shape长宽是颠倒的,在存储shape的时候要注意
- 归一化前要把类型cast到float,否则会计算错误
- 要可视化Image类,得先把图片张量cast到uint8,直接用float32会显示异常