【TensorFlow2.0】数据读取与使用方式
大家好,這是專欄《TensorFlow2.0》的第三篇文章,講述如何使用TensorFlow2.0讀取和使用自己的數據集。
如果您正在學習計算機視覺,無論你通過書籍還是視頻學習,大部分的教程講解的都是MNIST等已經為用戶打包封裝好的數據集,用戶只需要load_data即可實現數據導入。但是在我們平時使用時,無論您是做分類還是檢測或者分割任務,我們不可能每次都能找到打包好的數據集使用,大多數時候我們使用的都是自己的數據集,也就是我們需要從本地讀取文件。因此我們是很有必要學會數據預處理這個本領的。本篇文章,我們就聊聊如何使用TensorFlow2.0對自己的數據集進行處理。
作者&編輯 | 湯興旺
在TensorFlow2.0中,對數據處理的方法有很多種,下面我主要介紹兩種我自認為最好用的數據預處理的方法。
1 使用Keras API對數據進行預處理
1.1 數據集
本文用到的數據集是表情分類數據集,數據集有1000張圖片,包括500張微笑圖片,500張非微笑圖片。圖片預覽如下:
微笑圖片:
非微笑圖片:
數據集結構組織如下:
其中800張圖片用來訓練,200張用來測試,每個類別的樣本也是相同的。
1.2 數據預處理
我們知道,在將數據輸入神經網絡之前,需要將數據格式化為經過預處理的浮點數張量。現在我們看看數據預處理的步驟,如下圖:
這個步驟雖然看起來比較復雜,但在TensorFlow2.0的高級API Keras中有個比較好用的圖像處理的類ImageDataGenerator,它可以將本地圖像文件自動轉換為處理好的張量。
接下來我們通過代碼來解釋如何利用Keras來對數據預處理,完整代碼如下:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_data_dir = r"D://Learning//tensorflow_2.0//smile//data//train"
img_width,img_height = 48,48
train_datagen = ImageDataGenerator(
在上面的代碼中,我們首先導入ImageDataGenerator,即下面代碼:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
ImageDataGenerator是tensorflow.keras.preprocessing.image模塊中的圖片生成器,同時也可以使用它在batch中對數據進行增強,擴充數據集大小,從而增強模型的泛化能力。
ImageDataGenerator中有眾多的參數,如下:
tf.keras.preprocessing.image.ImageDataGenerator(
? ?featurewise_center=False,
? ?brightness_range, ? ?
? ?shear_range=0.,
具體含義如下:
featurewise_center:布爾值,使輸入數據集去中心化(均值為0)
samplewise_center:布爾值,使輸入數據的每個樣本均值為0。
featurewise_std_normalization:布爾值,將輸入除以數據集的標準差以完成標準化。
samplewise_std_normalization:布爾值,將輸入的每個樣本除以其自身的標準差。
zca_whitening:布爾值,對輸入數據施加ZCA白化。
rotation_range:整數,數據增強時圖片隨機轉動的角度。隨機選擇圖片的角度,是一個0~180的度數,取值為0~180。
width_shift_range:浮點數,圖片寬度的某個比例,數據增強時圖片隨機水平偏移的幅度。
height_shift_range:浮點數,圖片高度的某個比例,數據增強時圖片隨機豎直偏移的幅度。?
shear_range:浮點數,剪切強度(逆時針方向的剪切變換角度)。是用來進行剪切變換的程度。
zoom_range:浮點數或形如[lower,upper]的列表,隨機縮放的幅度,若為浮點數,則相當于[lower,upper] = [1 - zoom_range, 1+zoom_range]。用來進行隨機的放大。
channel_shift_range:浮點數,隨機通道偏移的幅度。
fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,當進行變換時超出邊界的點將根據本參數給定的方法進行處理。
cval:浮點數或整數,當fill_mode=constant時,指定要向超出邊界的點填充的值。
horizontal_flip:布爾值,進行隨機水平翻轉。隨機的對圖片進行水平翻轉,這個參數適用于水平翻轉不影響圖片語義的時候。
vertical_flip:布爾值,進行隨機豎直翻轉。
rescale: 值將在執行其他處理前乘到整個圖像上,我們的圖像在RGB通道都是0~255的整數,這樣的操作可能使圖像的值過高或過低,所以我們將這個值定為0~1之間的數。
preprocessing_function: 將被應用于每個輸入的函數。該函數將在任何其他修改之前運行。該函數接受一個參數,為一張圖片(秩為3的numpy array),并且輸出一個具有相同shape的numpy array。
下面看看我們對數據集增強后的一個效果,由于圖片數量太多,我們顯示其中9張圖片,增強后圖片如下:
大家可以多嘗試下每個增強后的效果,增加些感性認識,數據增強和圖片顯示代碼如下,只需要更改ImageDataGenerator中的參數,就能看到結果。
import matplotlib.pyplot as plt
? ? ? ? r"D://Learning//tensorflow_2.0//smile//datas//mouth//test",?
? ? ? ?batch_size=1,
? shuffle=False,? ? ? ? ? ? ? ? ? ? ? ? ? ?save_to_dir=r"D://Learning//tensorflow_2.0//smile//datas//mouth//model",
說完了數據增強,我們再看下ImageGenerator類下的函數flow_from_diectory。從這個函數名,我們也明白其就是從文件夾中讀取圖像。
train_generator = train_datagen.flow_from_directory(
flow_from_diectory中有如下參數:
directory:目標文件夾路徑,對于每一個類,該文件夾都要包含一個子文件夾。
target_size:整數tuple,默認為(256, 256)。圖像將被resize成該尺寸
color_mode:顏色模式,為"grayscale"和"rgb"之一,默認為"rgb",代表這些圖片是否會被轉換為單通道或三通道的圖片。
classes:可選參數,為子文件夾的列表,如['smile','neutral'],默認為None。若未提供,則該類別列表將從directory下的子文件夾名稱/結構自動推斷。每一個子文件夾都會被認為是一個新的類。(類別的順序將按照字母表順序映射到標簽值)。
class_mode: "categorical", "binary", "sparse"或None之一。默認為"categorical。該參數決定了返回的標簽數組的形式, "categorical"會返回2D的one-hot編碼標簽,"binary"返回1D的二值標簽。"sparse"返回1D的整數標簽,如果為None則不返回任何標簽,生成器將僅僅生成batch數據。
batch_size:batch數據的大小,默認32。
shuffle:是否打亂數據,默認為True。
seed:可選參數,打亂數據和進行變換時的隨機數種子。
save_to_dir:None或字符串,該參數能讓你將數據增強后的圖片保存起來,用以可視化。
save_prefix:字符串,保存數據增強后圖片時使用的前綴, 僅當設置了save_to_dir時生效。
save_format:"png"或"jpeg"之一,指定保存圖片的數據格式,默認"jpeg"。
這些參數中的directory一定要弄清楚,它是指類別文件夾的上一層文件夾,在該數據集中,類別文件夾為smile和neutral,它的上一級文件夾是train。所以director為?r"D://Learning//tensorflow_2.0//smile//data//train"
另外,class這個參數也要注意,通常我們就采用默認None,directory的子文件夾就是標簽。在該分類任務中標簽就是smile和neutral。
以上就是在TensorFlow2.0中利用Keras這個高級API來對分類任務中的數據進行預處理。另外如果您需要完成一個目標檢測等任務,則需要自定義一個類來繼承ImageDataGeneraton。具體怎么操作,請期待我們的下回關于如何利用TensorFlow2.0處理目標檢測任務的分享。
2 使用Dataset類對數據預處理
由于該方法在TensorFlow1.x版本中也有,大家可以比較查看2.0相對于1.x版本的改動地方。下面是TensorFlow2.0中使用的完整代碼:
import tensorflow as tf
#from tensorflow.contrib.data import Dataset
#from tensorflow.python.framework import dtypes
#from tensorflow.python.framework.ops import convert_to_tensor
? ? ? # 轉換成Tensor
? ? ? ?#self.img_paths=convert_to_tensor(self.img_paths, dtype=dtypes.string)
? ? ? ?#self.labels =convert_to_tensor(self.labels, dtype=dtypes.int32)
? ? ? ?#img = tf.read_file(filename)
? ? ? ?img = tf.io.read_file(filename)
? ? ? ?img = tf.image.convert_image_dtype(img, dtype=tf.float32)
? ? ? ?#img =tf.random_crop(img,[self.image_size[0],self.image_size[1],3]) ? ? ?
? ? ? ?img=tf.image.random_crop(img, [self.image_size[0], self.image_size[1], 3])
上圖中標紅色的地方是tensorFlow2.0版本與1.x版本的區別,紅色部分屬于1.X版本。主要更改在contrib部分,在tensorFlow2.0中已經刪除contrib了,其中有維護價值的模塊會被移動到別的地方,剩余的都將被刪除,這點大家務必注意。
如果您對上面代碼有任何不明白的地方請移步之前的文章:【tensorflow速成】Tensorflow圖像分類從模型自定義到測試
重要活動,本周有三AI紀念撲克牌發售中,只有不到100套的名額噢,先到先得!
有三AI紀念撲克牌
總結
本文主要介紹了如何在TensorFlow2.0中對自己的數據進行預處理。主要由兩種比較好用的方法,第一種是TensorFlow2.0中特有的,即利用Keras高級API對數據進行預處理,第二種是利用Dataset類來處理數據,它和TensorFlow1.X版本基本一致。
下期預告:使用TensorFlow構建深度學習網絡。
近期chat
今日看圖猜技術
知識匯總
有三AI生態
轉載文章請后臺聯系
侵權必究
更多請關注知乎專欄《有三AI學院》
往期精選
有三AI一周年了,說說我們的初衷,生態和愿景
總結
以上是生活随笔為你收集整理的【TensorFlow2.0】数据读取与使用方式的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 有三AI一周年了,说说我们的初衷,生态和
- 下一篇: 【AI不惑境】残差网络的前世今生与原理