【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型
本文主要講解在現有常用模型基礎上,如何微調模型,減少訓練時間,同時保持模型檢測精度。
首先介紹下Slim這個Google公布的圖像分類工具包,可在github鏈接:modules and examples built with tensorflow 中找到slim包。
上面這個鏈接目錄下主要包含:
official models(這個是用Tensorflow高層API做的例子模型集,建議初學者可嘗試);
research models(這個是很多研究者利用tensorflow做的模型集,這個不是官方提供的,是研究者個人在維護的);
samples folder (包含代碼片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代碼呈現);
而我說的slim工具包就在research文件夾下。
Slim庫結構
不僅定義了很多接口,還提供了很多ImageNet數據集上常用的網絡結構和預訓練模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。
?
下面用slim工具包中的文件來對自己的數據集做訓練,訓練可分為利用已有的模型架構(如常見的VGG,Inception等的卷積,池化這些結構)來全新訓練權重文件或是微調權重文件。由于很多已有的imagenet圖像數據覆蓋面已經很廣,基于此訓練的網絡權重已經能提取大致的目標特征(從低微像素到高維的結構特征),所以可使用fine-tune只訓練框架中某些層的權重,當然根據自己數據集做全部權重重新訓練的檢測效果理論會更好些,需要權衡時間成本和檢測精度的需求了;
下面會依據成熟網絡結構Incvption V3分別做權重文件的全部重新訓練和部分重新訓練(即fine-tune)來介紹;
(前提是你將slim工具庫下載下來,安裝了必要的tensorflow等框架;并且根據訓練圖像制作完成了tfrecord文件)
有關tfrecord訓練文件的制作請參考:將圖像制作成tfrecord
step1:定義新的datasets數據集文件
在slim/datasets/文件夾下 添加一個python文件,直接復制一份flowers.py,重命名為“satellite.py”(這個名字可根據你實際的數據集名字來更改,我用的是何大神的航拍圖數據集)
需要對賦值生成后的satellite.py內容做如下修改:
_FILE_PATTERN = 'flowers_%s_*.tfrecord'?
更改為
_FILE_PATTERN = 'satellite_%s_*.tfrecord' ? ?? #這個主要是根據你之前制作的tfrecord文件名來改的,我制作的訓練文件為satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,驗證文件為satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord。
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
更改為
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200} ?#這個根據自己訓練和驗證樣本數量來改,我的訓練數據是800張圖/類,共6類,驗證集時200張/類,共6類;
_NUM_CLASSES = 5
更改為
_NUM_CLASSES = 6 ? ? ? #實際訓練類別為6類;
?
還需要對satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),這行代碼做更改,由于用的數據集源文件都是XXXX.jpg格式,因此將默認的圖像格式轉為jpg,更改后為'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,對satellite.py文件完成制作與更改(其源碼如下):
satellite.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Provides data for the flowers dataset.The dataset scripts used to create the dataset can be found at: tensorflow/models/slim/datasets/download_and_convert_flowers.py """from __future__ import absolute_import from __future__ import division from __future__ import print_functionimport os import tensorflow as tffrom datasets import dataset_utilsslim = tf.contrib.slim_FILE_PATTERN = 'satellite_%s_*.tfrecord'SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}_NUM_CLASSES = 6_ITEMS_TO_DESCRIPTIONS = {'image': 'A color image of varying size.','label': 'A single integer between 0 and 4', }def get_split(split_name, dataset_dir, file_pattern=None, reader=None):"""Gets a dataset tuple with instructions for reading flowers.Args:split_name: A train/validation split name.dataset_dir: The base directory of the dataset sources.file_pattern: The file pattern to use when matching the dataset sources.It is assumed that the pattern contains a '%s' string so that the splitname can be inserted.reader: The TensorFlow reader type.Returns:A `Dataset` namedtuple.Raises:ValueError: if `split_name` is not a valid train/validation split."""if split_name not in SPLITS_TO_SIZES:raise ValueError('split name %s was not recognized.' % split_name)if not file_pattern:file_pattern = _FILE_PATTERNfile_pattern = os.path.join(dataset_dir, file_pattern % split_name)# Allowing None in the signature so that dataset_factory can use the default.if reader is None:reader = tf.TFRecordReaderkeys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),}items_to_handlers = {'image': slim.tfexample_decoder.Image(),'label': slim.tfexample_decoder.Tensor('image/class/label'),}decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)labels_to_names = Noneif dataset_utils.has_labels(dataset_dir):labels_to_names = dataset_utils.read_label_file(dataset_dir)return slim.dataset.Dataset(data_sources=file_pattern,reader=reader,decoder=decoder,num_samples=SPLITS_TO_SIZES[split_name],items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,num_classes=_NUM_CLASSES,labels_to_names=labels_to_names)step2:注冊數據庫
接下來對slim/datasets/dataset_factory.py文件做更改,注冊下satellite數據庫;修改之處如下(添加了兩行紅色字體代碼):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite
datasets_map = {
? ? 'cifar10': cifar10,
? ? 'flowers': flowers,
? ? 'imagenet': imagenet,
? ? 'mnist': mnist,
?? ?'satellite': satellite,
?? ?
}
step3:準備訓練文件夾
在slim文件夾下新建如下目錄文件夾,并將對應的文件放在相應目錄下
slim/
? ? satellite/
? ? ? ? ? ? ? data/
? ? ? ? ? ? ? ? ? ?satellite_train_00000-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_train_00001-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_validation_00000-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?satellite_validation_00001-of-00002.tfrecord
? ? ? ? ? ? ? ? ? ?label.txt
? ? ? ? ? ? ? pretrained/
? ? ? ? ? ? ? ? ? ?inception_v3.ckpt
? ? ? ? ? ? ? train_dir/
data文件夾下存放你制作的tfrecord訓練測試文件和標簽名;
pretrained文件夾下存放官網訓練的權重文件;下載地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz ? ? ?
train_dir文件夾下存放你訓練得到的模型和日志;
step4-1:在現有模型結構上fine-tune
開始訓練,在slim文件夾下,運行如下指令可開始訓練(主要是訓練邏輯層):
python train_image_classifier.py \--train_dir=satellite/train_dir \--dataset_name=satellite \--dataset_split_name=train \--dataset_dir=satellite/data \--model_name=inception_v3 \--checkpoint_path=satellite/pretrained/inception_v3.ckpt \--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \--max_number_of_steps=100000 \--batch_size=32 \--learning_rate=0.001 \--learning_rate_decay_type=fixed \--save_interval_secs=300 \--save_summaries_secs=2 \--log_every_n_steps=10 \--optimizer=rmsprop \--weight_decay=0.00004命令參數解析如下:
? --trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先來解 釋參數trainable_scopes 的作用,因為非常重要。 trainable_scopes 規定了在模型中fine-tune變量的范圍 。 這里的設定表示只對 InceptionV3/Logits, Inception V3/ AuxLogits 兩個變量進行微調,其他變量都保持不動 。 Inception V3/Logits,Inception V3/ AuxLogits 就相當于在網絡中的 fc8 ,它們是 Inception V3的“末端層” 。 如果不設定 trainable_scopes , 就會對模型中所有的參數進行訓練。
? --train_dir=satellite/train_dir:表明會在 satellite/train_dir目錄下保存日志和checkpoint。
? --dataset_name=satellite、 --dataset_split_ name=train: 指定訓練的數據集。
? --dataset_dit=satellite/data:指定訓練數據集保存的位置。?
? --model_ name=inception _ v3 :使用的模型名稱。?
? --checkpoint_path=satellite/pretrained/inception_v3.ckpt:預訓練模型的保存位置。
? --checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢復預訓練模型時,不恢復這兩層。正如之前所說,這兩層是 Inception V3 模型的末端層,對應著 ImageNet 數據集的 1000 類,和相當前的數據集不符,因此不要去恢復它。
? --max_number_of_steps 100000:最大的執行步數。
? --batch_size=32:每步使用的 batch 數量。
? --learning_rate=0.001 : 學習率。
? --learning_rate_decay_type=fixed:學習率是否自動下降,此處使用固定的學習率。
? --save_interval_secs=300:每隔 300s,程序會把當前模型保存到train_dir中。 此處就是目錄 satellite/train_dir。
? --save_summaries_secs=2:每隔 2s,就會將日志寫入到 train_dir 中。可以用 TensorBoard 查看該日志。此處為了方便觀察,設定的時間間隔較多,實際訓練時,為了性能考慮,可以設定較長的時間間隔。
? --log_every_n_steps=10:每隔10步,就會在屏上打出訓練信息。
? --optimizer=msprop:表示選定的優化器。
? --weight_decay=0.00004:選定的 weight_decay 值。 即模型中所高參數的 二次正則化超參數。
以上命令是只訓練末端層 InceptionV3/Logits,Inception V3/ AuxLogits ,還 可以使用以下命令對所高層進行訓練:
step4-2:訓練整個模型權重數據
使用以下命令對所有層進行訓練:
去掉 了--trainable_scopes 參數
當train_image_classifier.py程序啟動后,如果訓練文件夾(即satellite/train_dir)里沒再已經保存的模型,就會加載 checkpoint_path 中的預訓練模型,緊接著,程序會把初始模型保存到 train_dir中 ,命名為 model.ckpt-0, 0 表示第 0 步。 這之后,每隔 5min (參數一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序還會把當前模型保存到同樣的文件夾中 , 命名恪式和第一次保存的格式一樣。 因為模型比較大,程序只會保留最新的 5 個模型。
此外,如果中斷了程序并再次運行,程序會首先檢查 train_dir 中有無已經保存的模型,如果有,就不會去加載 checkpoint_path 中的預訓練模型, 而是直接加載 train_dir 中已經訓練好的模型,并以此為起點進行訓練。 Slim 之所以這樣設計,是為了在微調網絡的時候,可以方便地按階段手動調整學習率等參數。
?
至此用slim工具包做fine-tune或重新訓練的步驟就完成了。
相似文章參考:https://blog.csdn.net/chaipp0607/article/details/74139895
總結
以上是生活随笔為你收集整理的【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 工作介绍xml书包文件
- 下一篇: 为什么厉害的人(我)都精力那么好?我有四