彻底解决tfrecord读写问题

方案1

官方套娃方案存在的问题:

# 序列化
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'feature0': tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature0])),
            'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[feature1])),
            'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[feature2])),
    }))
example.SerializeToString()

# 反序列化
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.string),
    'feature1': tf.io.FixedLenFeature([], tf.float32),
    'feature2': tf.io.FixedLenFeature([], tf.int64),
}
tf.io.parse_single_example(parsed_example, feature_description)

优点:格式自由;缺点:在这个过程中,无法保存feature0的shape,所以解决方案是把shape也保存起来,反序列化之后先获得shape,再reshape,但是这种方法需要在用的时候再调整。

方案2

# 序列化
serialized_example = tf.io.serialize_tensor(tensor)
# 反序列化
tf.io.parse_tensor(serialized_example, dtype)

优点:此种方法能保持shape的将一个Tensor进行序列化和反序列化;缺点:存的时候得把x和y拼成一个tensor,如果x的shape比较复杂,则不容易把y拼起来。

方案3

结合两者特点:再进一步套娃

# 序列化
def serialize_example(feature0, feature1):
    feature = {
        'feature0': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(feature0).numpy()])),
        'feature1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(feature1).numpy()])),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
# 反序列化
feature_description = {
        'feature0': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'feature1': tf.io.FixedLenFeature([], tf.string, default_value=''),
    }
example = tf.io.parse_single_example(serialized_example, feature_description)
feature0 = tf.io.parse_tensor(example['feature0'], dtype)
feature1 = tf.io.parse_tensor(example['feature1'], dtype)

这种方法既有格式自由的优点,还有保持数据shape的优点,那缺点肯定是有的,这种方案进一步套娃,时间和空间复杂度都上去了。 经测试,这种方案和方案2中的序列化Tensor方案对比:不读写文件的情况下,序列化之后再反序列化一个数据,时间复杂度差异位套娃方案比序列化方案高10倍以上。在预先写入序列化数据到tfrecord中,然后比较两者读出数据的时间,前者比后者所用高5倍左右。文件大小方面,前者比后者高约20%。但是因为拼接y的时候可能要和x对齐,需要把y补一些数据,所以实际上可能后者数据反而会更大。


本文章使用limfx的vscode插件快速发布