生活随笔
收集整理的這篇文章主要介紹了
TF2.0 TFRecord创建和读取
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
本人為在職研究生,希望能夠有志同道合的學習伙伴一起學習分享和交流,本人領域為光纖傳感和人工智能,希望可以一起學習。
微信公眾號:Deepthinkerr(文章末尾有圖)
文章目錄
- 1. TFRecord
- 1.1 數據寫入TFRecord
- 1.2 讀取TFRecord
- 2. 微信公眾號
1. TFRecord
這里不太建議看《Tensorflow2.0卷積神經網絡實戰》王曉華這本書,講的內容并沒有什么問題,但是代碼我嘗試了,很多報錯(報錯說函數沒有參數,不知道是不是我的tf版本問題),而且一些地方沒有講清楚,這里建議看rensorflow的官網教程,看了一遍整體還是蠻不錯的,一些函數和書上講的不一樣,這里建議看官方的。
Tensorflow官網TFRecord鏈接:https://tensorflow.google.cn/tutorials/load_data/tfrecord?hl=zh_cn
1.1 數據寫入TFRecord
將數據寫入TFrecord步驟較為固定,個人總結為以下幾個步驟:
- step1:將每個值轉換為包含三種兼容類型之一的 tf.train.Feature
- step2:創建一個從特征名稱字符串到第 1 步中生成的編碼特征值的映射(字典)
- step3:將第 2 步中生成的映射轉換為Features消息
- step4:創建example
- step5:寫入TFRecords
def _bytes_feature(value
):"""Returns a bytes_list from a string / byte."""if isinstance(value
, type(tf
.constant
(0))):value
= value
.numpy
() return tf
.train
.Feature
(bytes_list
=tf
.train
.BytesList
(value
=[value
]))def _float_feature(value
):"""Returns a float_list from a float / double."""return tf
.train
.Feature
(float_list
=tf
.train
.FloatList
(value
=[value
]))def _int64_feature(value
):"""Returns an int64_list from a bool / enum / int / uint."""return tf
.train
.Feature
(int64_list
=tf
.train
.Int64List
(value
=[value
]))def image_feature(image
, label
):image_shape
= image
.shapefeature
= {'height': _int64_feature
(image_shape
[0]), 'width': _int64_feature
(image_shape
[1]),'depth': _int64_feature
(image_shape
[2]),'label': _int64_feature
(label
),'image_raw': _bytes_feature
(image
.tobytes
()) } return tf
.train
.Features
(feature
=feature
)
data
= np
.random
.random
([100, 28, 28, 3])
label
= np
.random
.randint
(0, 10, 100)
record_file
= 'test.tfrecord'
with tf
.io
.TFRecordWriter
(record_file
) as writer
:for i
in range(100):tf_example
= tf
.train
.Example
(features
=image_feature
(data
[i
], label
[i
])) writer
.write
(tf_example
.SerializeToString
())
注意:step1使用的是tf.train.Feature函數,函數參數為bytes_list、float_list、int64_list三種,是tf.train.Feature接受的三種類型,step3使用的是tf.train.Features,有個s,這里需要注意,而且函數參數為feature
個人理解:tf.train.Feature和tf.train.Features區別,前者是對一個數據進行Feature消息處理,后者是對多個Feature組成的dict進行處理,所以多一個s。
1.2 讀取TFRecord
讀取TFRecord文件步驟也比較固定,主要是需要對讀取的byte數據進行解析,這里必須和創建TFRecord時相對應,步驟如下:
- step1:創建解析字典(和寫入相對應)
- step2:解析函數(將讀取的數據解析為寫入的數據)
- step3:讀取TFRecords文件
- step4:解析數據
image_feature_description
= {'height': tf
.io
.FixedLenFeature
([], tf
.int64
),'width': tf
.io
.FixedLenFeature
([], tf
.int64
),'depth': tf
.io
.FixedLenFeature
([], tf
.int64
),'label': tf
.io
.FixedLenFeature
([], tf
.int64
),'image_raw': tf
.io
.FixedLenFeature
([], tf
.string
),
}
def _parse_image_function(example_proto
):return tf
.io
.parse_single_example
(example_proto
, image_feature_description
) raw_image_dataset
= tf
.data
.TFRecordDataset
('test.tfrecord')
parsed_image_dataset
= raw_image_dataset
.map(_parse_image_function
)
讀取出來的圖片數據是string,也就是byte格式,如果是顯示的話直接通過IPython.display.display即可顯示圖片,如果要轉換為array需要進行數據類型轉換和reshape。
2. 微信公眾號
總結
以上是生活随笔為你收集整理的TF2.0 TFRecord创建和读取的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。