08_04基于手写数据集_mat保存模型参数
生活随笔
收集整理的这篇文章主要介绍了
08_04基于手写数据集_mat保存模型参数,
小编觉得挺不错的,现在分享给大家,帮大家做个参考。
import os

import numpy as np
import tensorflow as tf
from scipy import io
from tensorflow.examples.tutorials.mnist import input_data

# 1. Hyper-parameters
learning_rate = 0.001
epochs = 10
batch_size = 128
test_valid_size = 512  # number of samples used for validation / testing
n_classes = 10
keep_probab = 0.75  # dropout keep probability used during training


def conv2d_block(input_tensor, filter_w, filter_b, stride=1):
    """Convolution + bias add + ReLU6 activation.

    :param input_tensor: 4-D input tensor [N, H, W, C]
    :param filter_w: convolution kernel variable
    :param filter_b: bias variable (1-D, one value per output channel)
    :param stride: spatial stride (applied to both H and W)
    :return: activated feature map
    """
    conv = tf.nn.conv2d(input=input_tensor, filter=filter_w,
                        strides=[1, stride, stride, 1], padding='SAME')
    conv = tf.nn.bias_add(conv, filter_b)
    conv = tf.nn.relu6(conv)
    return conv


def maxpool(input_tensor, k=2):
    """k x k max pooling with stride k and SAME padding."""
    ksize = [1, k, k, 1]
    strides = [1, k, k, 1]
    max_out = tf.nn.max_pool(value=input_tensor, ksize=ksize,
                             strides=strides, padding='SAME')
    return max_out


def model(input_tensor, keep_prob, pre_trained_weights=None):
    """Build the CNN graph: two conv/pool stages, one FC layer, logits.

    :param input_tensor: placeholder for input images [N, 28, 28, 1]
    :param keep_prob: placeholder for the dropout keep probability
    :param pre_trained_weights: optional dict loaded from a .mat file with
        keys 'w_conv1:0', 'w_conv2:0', 'w_fc1:0', 'w_logits:0',
        'b_conv1:0', 'b_conv2:0', 'b_fc1:0', 'b_logits:0'.
        When given, the conv layers are restored frozen (trainable=False),
        the FC/logits weights are restored and fine-tuned, and the
        FC/logits biases restart from zero (original behavior, kept).
    :return: (logits, prediction) tensors
    """
    if pre_trained_weights:
        W = pre_trained_weights
        weights = {
            'conv1': tf.get_variable('w_conv1', dtype=tf.float32,
                                     initializer=W['w_conv1:0'], trainable=False),
            'conv2': tf.get_variable('w_conv2', dtype=tf.float32,
                                     initializer=W['w_conv2:0'], trainable=False),
            'fc1': tf.get_variable('w_fc1', dtype=tf.float32,
                                   initializer=W['w_fc1:0'], trainable=True),
            'logits': tf.get_variable('w_logits', dtype=tf.float32,
                                      initializer=W['w_logits:0'], trainable=True),
        }
        biases = {
            # loadmat returns 1-D arrays as 2-D; flatten back to 1-D
            'conv1': tf.get_variable('b_conv1', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv1:0'], -1),
                                     trainable=False),
            'conv2': tf.get_variable('b_conv2', dtype=tf.float32,
                                     initializer=np.reshape(W['b_conv2:0'], -1),
                                     trainable=False),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }
    else:
        weights = {
            'conv1': tf.get_variable('w_conv1', shape=[5, 5, 1, 32], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'conv2': tf.get_variable('w_conv2', shape=[5, 5, 32, 64], dtype=tf.float32,
                                     initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'fc1': tf.get_variable('w_fc1', shape=[7 * 7 * 64, 1024], dtype=tf.float32,
                                   initializer=tf.truncated_normal_initializer(stddev=0.1)),
            'logits': tf.get_variable('w_logits', shape=[1024, n_classes], dtype=tf.float32,
                                      initializer=tf.truncated_normal_initializer(stddev=0.1)),
        }
        biases = {
            'conv1': tf.get_variable('b_conv1', shape=[32], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'conv2': tf.get_variable('b_conv2', shape=[64], dtype=tf.float32,
                                     initializer=tf.zeros_initializer()),
            'fc1': tf.get_variable('b_fc1', shape=[1024], dtype=tf.float32,
                                   initializer=tf.zeros_initializer()),
            'logits': tf.get_variable('b_logits', shape=[n_classes], dtype=tf.float32,
                                      initializer=tf.zeros_initializer()),
        }

    # 1. conv1  [N, 28, 28, 1]  -> [N, 28, 28, 32]
    conv1 = conv2d_block(input_tensor=input_tensor,
                         filter_w=weights['conv1'], filter_b=biases['conv1'])
    # 2. pool1  [N, 28, 28, 32] -> [N, 14, 14, 32]
    pool1 = maxpool(conv1, k=2)
    # 3. conv2  [N, 14, 14, 32] -> [N, 14, 14, 64]
    conv2 = conv2d_block(input_tensor=pool1,
                         filter_w=weights['conv2'], filter_b=biases['conv2'])
    conv2 = tf.nn.dropout(conv2, keep_prob=keep_prob)
    # 4. pool2  [N, 14, 14, 64] -> [N, 7, 7, 64]
    pool2 = maxpool(conv2, k=2)
    # 5. flatten [N, 7, 7, 64]  -> [N, 7*7*64]
    x_shape = pool2.get_shape()
    flatten_shape = x_shape[1] * x_shape[2] * x_shape[3]
    flatted = tf.reshape(pool2, shape=[-1, flatten_shape])
    # 6. fully connected layer
    fc1 = tf.nn.relu6(tf.matmul(flatted, weights['fc1']) + biases['fc1'])
    fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)
    # 7. logits layer
    logits = tf.add(tf.matmul(fc1, weights['logits']), biases['logits'])
    with tf.variable_scope('prediction'):
        prediction = tf.argmax(logits, axis=1)
    return logits, prediction


def create_dir_path(path):
    """Create *path* (and parents) if it does not exist yet."""
    if not os.path.exists(path):
        os.makedirs(path)
        print('create file path:{}'.format(path))


def store_weights(sess, save_path):
    """Persist all graph variables to a MATLAB .mat file.

    :param sess: a live tf.Session holding the variable values
    :param save_path: destination .mat file path
    """
    # 1. Collect ALL variables, not only the trainable ones: when resuming,
    #    the conv layers are created with trainable=False and would otherwise
    #    be silently dropped from the file, breaking the next reload.
    vars_list = tf.global_variables()
    # 2. Fetch the current values
    vars_values = sess.run(vars_list)
    # 3. Build a {variable_name: value} dict (names like 'w_conv1:0')
    mdict = {}
    for values, var in zip(vars_values, vars_list):
        mdict[var.name] = values
    # 4. Save in MATLAB format
    io.savemat(save_path, mdict)
    print('Saved Vars to files:{}'.format(save_path))


def train():
    """Train the CNN on MNIST, resuming from a saved .mat file if present."""
    # Directory used both for loading old weights and saving new ones.
    checkpoint_dir = './model/mnist/matlab/ai20'
    create_dir_path(checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        # 1. placeholders
        x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
        y = tf.placeholder(tf.float32, [None, 10], name='y')
        keep_prob = tf.placeholder_with_default(keep_probab, shape=None,
                                                name='keep_prob')

        # 2. Build the model. Only consider real .mat files; previously an
        #    empty directory (or a stray non-file entry) left `logits`
        #    undefined and crashed with NameError on the first run.
        weights_path = checkpoint_dir
        mat_files = [f for f in os.listdir(weights_path)
                     if f.endswith('.mat')
                     and os.path.isfile(os.path.join(weights_path, f))]
        if mat_files:
            weight_file = os.path.join(weights_path, mat_files[0])
            mdict = io.loadmat(weight_file)
            logits, prediction = model(x, keep_prob, pre_trained_weights=mdict)
            print('Load old model continue to train!')
        else:
            logits, prediction = model(x, keep_prob)
            print('No old model, train from scratch!')

        # 3. loss
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y))
        # optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_opt = optimizer.minimize(loss)
        # accuracy
        correct_pred = tf.equal(tf.argmax(y, axis=1), prediction)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        mnist = input_data.read_data_sets('../datas/mnist', one_hot=True,
                                          reshape=False)
        step = 1
        while True:
            # one training step
            batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
            feed = {x: batch_x, y: batch_y}
            _, train_loss, train_acc = sess.run([train_opt, loss, accuracy], feed)
            print('Step:{} - Train Loss:{:.5f} - Train acc:{:.5f}'.format(
                step, train_loss, train_acc))
            # Persist periodically (this was commented out before, so the
            # script named "save model params" never actually saved).
            if step % 100 == 0:
                file_name = 'model_{:.3f}.mat'.format(train_acc)
                store_weights(sess, save_path=os.path.join(checkpoint_dir, file_name))
            step += 1
            # stop once training accuracy is high enough
            if train_acc > 0.99:
                break


if __name__ == '__main__':
    train()
總結
以上是生活随笔为你收集整理的08_04基于手写数据集_mat保存模型参数的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 34岁Java程序员裸辞,mysql实战
- 下一篇: 使用支付宝原生插件(hbuilderX)