Skip to content

TfRecord的一些坑

DingfengShi edited this page Feb 26, 2018 · 6 revisions

一篇写的很好的博文

TfRecord有三种数据类型:

  • Int64List
  • FloatList
  • BytesList

这处理的是标量,也就是一维列表(所以要注意Feature构造函数里形参value一定要赋值成一维列表list)。放入的数据类型是ndarray。 由于序列化以后只能存储一维数据,所以如果要放入张量,必须把张量flatten以后存入(图片tobytes以后也是一维的)。如果形状是不固定的,这时候就可以用另外一个Int64List去存储一个"shape"的list,用于复原数据形状。

写入TFRecord例子

writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
    class_path = cwd + name + "/"
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((224, 224))
        img_raw = img.tobytes()              #将图片转化为原生bytes
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

读取TFRecord例子

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [224, 224, 3])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(features['label'], tf.int32)
  • PIL库中Image类的size和tensorflow里的shape长宽是颠倒的,在存储shape的时候要注意
  • 归一化前要把类型cast到float,否则会计算错误
  • 要可视化Image类,得先把图片张量cast到uint8,直接用float32会显示异常
Clone this wiki locally