生活随笔
收集整理的這篇文章主要介紹了
Tensorflow实现MNIST数据自编码(3)
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
前面自編碼(1)和自編碼(2)是針對(duì)高維數(shù)據(jù)維數(shù)進(jìn)行降低維數(shù)角度改進(jìn)模型,但是還需要讓這些特征具有抗干擾能力,輸入的特征數(shù)據(jù)受到干擾時(shí),生成特征依然不會(huì)怎么變化,使自動(dòng)編碼器具有更好的泛化能力
??import?tensorflow?as?tf??import?numpy?as?np??import?matplotlib.pyplot?as?plt??????from?tensorflow.examples.tutorials.mnist?import?input_data??mnist?=?input_data.read_data_sets('/data/',one_hot=True)????train_x?=?mnist.train.images??train_y?=?mnist.train.labels??test_x?=?mnist.test.images??test_y?=?mnist.test.labels????n_hidden_1?=?256???????n_input?=?784????x?=?tf.placeholder('float',[None,n_input])??y?=?tf.placeholder('float',[None,n_input])??dropout_keep_prob?=?tf.placeholder('float')????weights?=?{??????'h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),??????'h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_1])),??????'out':tf.Variable(tf.random_normal([n_hidden_1,n_input])),??}??biases?=?{??????'b1':tf.Variable(tf.zeros([n_hidden_1])),??????'b2':tf.Variable(tf.zeros([n_hidden_1])),??????'out':tf.Variable(tf.zeros([n_input]))??}??def?denoise_auto_encoder(X,weights,biases,keep_prob):??????layer_1?=?tf.nn.sigmoid(tf.add(tf.matmul(X,weights['h1']),biases['b1']))??????layer_1out?=?tf.nn.dropout(layer_1,keep_prob)????????layer_2?=?tf.nn.sigmoid(tf.add(tf.matmul(layer_1out,weights['h2']),biases['b2']))??????layer_2out?=?tf.nn.dropout(layer_2,keep_prob)????????return?tf.nn.sigmoid(tf.matmul(layer_2out,weights['out'])+biases['out'])????reconstruction?=?denoise_auto_encoder(x,weights,biases,dropout_keep_prob)??cost?=?tf.reduce_mean(tf.pow(reconstruction-y,2))??optm?=?tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)????epochs?=?20??batch_size?=?256??disp_step?=?2????with?tf.Session()?as?sess:??????sess.run(tf.global_variables_initializer())??????print('Start?training')??????for?epoch?in?range(epochs):??????????num_batch?=?int(mnist.train.num_examples/batch_size)??????????total_cost?=?0??????????for?i?in?range(num_batch):??????????????batch_xs,batch_ys?=?mnist.train.next_batch(batch_size)??????????????batch_xs_noisy?=?batch_xs+0.3*np.random.randn(batch_size,784)??????????????feeds?=?{x:batch_xs_noisy,y:batch_xs,dropout_keep_prob:1.0}??????????????sess.run(optm,feed_dict=feeds)??????????????total_cost+=sess.run(cost,feed_dict=feeds)????????????????if?epoch%disp_step==0:??????????????????print('Epoch?%2d/%2d?average?cost:%.6f'%(epoch,epochs,total_cost/num_batch))????????print('Finished')??
總結(jié)
以上是生活随笔為你收集整理的Tensorflow实现MNIST数据自编码(3)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
如果覺(jué)得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。