TFRecord 统一数据格式
生活随笔
收集整理的這篇文章主要介紹了
TFRecord 统一数据格式
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1.將輸入數據保存為TFRecord格式
__author__ = 'ding' ''' 將數據保存為TFRecord格式 ''' import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np# 生成整數型的屬性 def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# 生成字符串型的屬性 def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))mnist = input_data.read_data_sets('./path/to/mnist/data', dtype=tf.uint8, one_hot=True) images = mnist.train.images # 訓練數據對應的正確答案,作為一個屬性保存在TFRecord中 labels = mnist.train.labels # 訓練數據的圖像分辨率,作為Example中的一個屬性 pixels = images.shape[1] num_examples = mnist.train.num_examples# 輸出文件的路徑 filename = './path/to/output.tfrecords' # 創建一個writer寫TFRecord文件 writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples):# 將圖像矩陣轉化成一個字符串image_raw = images[index].tostring()# 將一個樣例轉換成Example Protocol Buffer,并將所有信息寫入這個數據結構example = tf.train.Example(features=tf.train.Features(feature={'pixels': _int64_feature(pixels),'label': _int64_feature(np.argmax(labels[index])),'image_raw':_bytes_feature(image_raw)}))# 將一個Example 寫入TFRecord文件writer.write(example.SerializeToString())writer.close()在工程的/path/to目錄下生成一個output.tfrecord文件,這個文件就是輸入數據的TFRecord格式文件
此處注意 不要遺漏tf.train.Features后的s,為了不必要的錯誤,需要仔細核對(掉過坑,所以提醒。。)
2.讀取TFRecord文件中的數據
__author__ = 'ding' ''' 讀取TFRecord文件中的數據 ''' import tensorflow as tf# 創建一個reader來讀取TFRecord文件中的樣例 reader = tf.TFRecordReader()# 創建一個列隊來維護輸入文件列表 filename_queue = tf.train.string_input_producer(['./path/to/output.tfrecords']) # 從文件中讀取一個樣例,也可以使用read_up_to函數一次讀取多個樣例 _,serialized_example = reader.read(filename_queue)# 解析讀入的一個樣例,如果需要解析多個樣例,也可以使用parse_example函數 features = tf.parse_single_example(serialized_example,features={# 解析方法與保存方法應該一致,避免報錯# TensorFlow 有兩種屬性解析的方法,# tf.FixedLenGeature, 解析結果為一個Tensor# tf.VarLenFrature,解析結果為SparseTensor,用于稀疏處理'image_raw':tf.FixedLenFeature([],tf.string),'pixels':tf.FixedLenFeature([],tf.int64),'label':tf.FixedLenFeature([],tf.int64)})# tf.decode_raw 可以將字符串解析成圖像對應的像素數組 images = tf.decode_raw(features['image_raw'],tf.uint8) labels = tf.cast(features['label'],tf.int32) pixels = tf.cast(features['pixels'],tf.int32)sess = tf.Session() # 啟用多線程處理 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord)# 每次運行可以讀取TFRecord文件中的一個樣例,當所有樣例讀完之后,此樣例中程序會在重頭讀取 for i in range(10):images,labels,pixels = sess.run([images,labels,pixels]) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor ***原因:變量命名重復了總結
以上是生活随笔為你收集整理的TFRecord 统一数据格式的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PLC立体车库智能仿真 博途V15 3×
- 下一篇: java通过LocalDateTime获