I learned about the deconvolution (transposed convolution) operation from DCGAN, so my plan here is to use convolutions as an encoder that compresses one image frame into a 20-dimensional vector, and then use deconvolutions as a decoder to reconstruct the image from that vector. I'll paste the program first and tune the number of layers and the parameters later when I have time.
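Before the full program, here is a minimal sketch of my own (assuming TensorFlow 1.x) of the tf.nn.conv2d_transpose call the decoder relies on: with strides=[1, 2, 2, 1] and padding='SAME' it doubles the spatial size, its filter layout is [height, width, out_channels, in_channels] (the reverse of tf.nn.conv2d), and the output shape has to be given explicitly.

import tensorflow as tf

# A 3x3 transposed-convolution filter: [height, width, out_channels, in_channels]
w_up = tf.Variable(tf.truncated_normal([3, 3, 64, 1], stddev=0.1))
x_small = tf.ones([100, 7, 7, 1])                     # dummy batch of 7x7x1 feature maps
x_big = tf.nn.conv2d_transpose(x_small, w_up,
                               output_shape=[100, 14, 14, 64],
                               strides=[1, 2, 2, 1], padding='SAME')
print(x_big.get_shape())                              # -> (100, 14, 14, 64)

The full program follows.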
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt

mnist = input_data.read_data_sets("/homemnist/raw/", one_hot=True)
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1, 28, 28, 1])

def conv2d(x, w):
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def deconv2d(x, w, shape):
    # w = tf.get_variable('w', [3, 3, shape[-1], x.get_shape()[-1]],
    #                     initializer=tf.random_normal_initializer(stddev=0.02))
    return tf.nn.conv2d_transpose(x, w, output_shape=shape, strides=[1, 2, 2, 1], padding='SAME')
w_enconv1 = tf.Variable(tf.truncated_normal([3, 3, 1, 16], stddev=0.1), name='w_1')  # e.g. [5,5,1,32] would mean 5x5 kernels, 1 input channel, 32 filters
b_enconv1 = tf.Variable(tf.constant(0.1, shape=[16]), name='b_1')  # bias_variable([32])
w_enconv2 = tf.Variable(tf.truncated_normal([3, 3, 16, 8], stddev=0.1), name='w_2')
b_enconv2 = tf.Variable(tf.constant(0.1, shape=[8]), name='b_2')
w_enconv3 = tf.Variable(tf.truncated_normal([3, 3, 8, 1], stddev=0.1), name='w_3')
b_enconv3 = tf.Variable(tf.constant(0.1, shape=[1]), name='b_3')
w_fc = tf.Variable(tf.random_normal([49, 20], stddev=0.01), name='w_4')
w_defc = tf.Variable(tf.random_normal([20, 49], stddev=0.01), name='w_5')
b_defc = tf.Variable(tf.constant(0.1, shape=[49]), name='b_5')
w_deconv2 = tf.Variable(tf.truncated_normal([3, 3, 64, 1], stddev=0.1), name='w_6')  # conv2d_transpose filters are [h, w, out_channels, in_channels]
w_deconv3 = tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1), name='w_7')
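# Added note on the shape flow implied by the weights above:
# Encoder: 28x28x1 -> conv/pool -> 14x14x16 -> conv/pool -> 7x7x8 -> conv -> 7x7x1
#          -> flatten to 49 -> fully connected -> 20-dim code.
# Decoder: 20 -> fully connected -> 49 -> reshape to 7x7x1
#          -> conv2d_transpose -> 14x14x64 -> conv2d_transpose -> 28x28x1.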
def encoder(x_image, w_enconv1, b_enconv1, w_enconv2, b_enconv2, w_enconv3, b_enconv3, w_fc):
    h_conv1 = tf.nn.relu(conv2d(x_image, w_enconv1) + b_enconv1)
    h_pool1 = max_pool_2x2(h_conv1)                                # 28x28x16 -> 14x14x16
    h_conv2 = tf.nn.relu(conv2d(h_pool1, w_enconv2) + b_enconv2)
    h_pool2 = max_pool_2x2(h_conv2)                                # 14x14x8 -> 7x7x8
    h_conv3 = tf.nn.relu(conv2d(h_pool2, w_enconv3) + b_enconv3)
    # conv_shape = h_conv3.get_shape().as_list()
    # nodes = conv_shape[1] * conv_shape[2] * conv_shape[3]  # vector length = height * width * depth
    h_f = tf.reshape(h_conv3, [-1, 49])   # flatten each 7x7x1 map; -1 infers the batch size
    h_fc = tf.nn.relu(tf.matmul(h_f, w_fc))
    return h_fc

def decoder(x, w_defc, b_defc, w_deconv2, w_deconv3):
    h_0 = tf.nn.relu(tf.add(tf.matmul(x, w_defc), b_defc))
    h_1 = tf.reshape(h_0, [-1, 7, 7, 1])
    h_deconv1 = tf.nn.sigmoid(deconv2d(h_1, w_deconv2, [batch_size, 14, 14, 64]), name='g_h1')
    h_deconv2 = tf.nn.sigmoid(deconv2d(h_deconv1, w_deconv3, [batch_size, 28, 28, 1]), name='g_h2')
    return h_deconv2
learning_rate = 0.01
epochs = 100
batch_size = 100
display_step = 5
encoder_op = encoder(x_image,w_enconv1,b_enconv1,w_enconv2,b_enconv2,w_enconv3,b_enconv3,w_fc)
decoder_op = decoder(encoder_op,w_defc,b_defc,w_deconv2,w_deconv3)
y_pred = decoder_op
y_true = x_image
loss = tf.reduce_mean(tf.pow(y_true-y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # sess.run(init)
    total_batch = int(mnist.train.num_examples / batch_size)
    for epoch in range(epochs):
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, loss], feed_dict={x: batch_xs})
        if epoch % display_step == 0:
            print("epoch:", '%04d' % (epoch + 1))
            # print("epoch:", '%04d' % (epoch + 1), "cost =", "{:.9f}".format(c))
    print("over!")
    fh = mnist.test.images[:batch_size]
    encode_decoder = sess.run(y_pred, feed_dict={x: fh})
    plt.subplot(1, 2, 1)
    plt.imshow(np.reshape(fh[1], (28, 28)))
    plt.subplot(1, 2, 2)
    plt.imshow(np.reshape(encode_decoder[1], (28, 28)))
    plt.show()
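As a small follow-up of my own (not part of the original script, only a sketch): once training is done, the encoder half can be run on its own inside the same with tf.Session() block to pull out the 20-dimensional codes, which is ultimately what I want from each frame.

    # Added sketch: extract the 20-dim codes for the same test batch.
    codes = sess.run(encoder_op, feed_dict={x: fh})
    print("code shape:", codes.shape)   # expected (100, 20)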