tensorflow實踐 手寫MNIST數字識別

2021-09-25 12:36:56 字數 2260 閱讀 8082

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 載入資料集,讀取的是壓縮包

mnist = input_data.read_data_sets("mnist",one_hot=true)

# 每個批次的大小

batch_size = 50

# 計算一共有多少個批次

n_batch = mnist.train.num_examples // batch_size

# 定義兩個placeholder

x = tf.placeholder(tf.float32,[none,784]) # 28*28 一張的向量

y = tf.placeholder(tf.float32,[none,10])

keep_prob = tf.placeholder(tf.float32)

lr = tf.variable(0.001,dtype=tf.float32) #定義乙個學習率

# 建立乙個神經網路

w1 = tf.variable(tf.truncated_normal([784,500],stddev=0.1))

b1 = tf.variable(tf.zeros([1,500])+0.1)

l1 = tf.nn.tanh(tf.matmul(x,w1)+b1)

l1_drop = tf.nn.dropout(l1,keep_prob)

w2 = tf.variable(tf.truncated_normal([500,300],stddev=0.1))

b2 = tf.variable(tf.zeros([1,300])+0.1)

l2 = tf.nn.tanh(tf.matmul(l1_drop,w2)+b2)

l2_drop = tf.nn.dropout(l2,keep_prob)

w3 = tf.variable(tf.truncated_normal([300,10],stddev=0.1))

b3 = tf.variable(tf.zeros([1,10])+0.1)

prediction = tf.nn.softmax(tf.matmul(l2_drop,w3)+b3)

# 二次代價函式

# loss = tf.reduce_mean(tf.square(y-prediction))

# 交叉熵

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))

# 使用梯度下降

# train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)

train_step = tf.train.adamoptimizer(lr).minimize(loss)

# 初始化變數

init = tf.global_variables_initializer()

# argmax返回張量中最大的值所在的位置

# 結果存放在乙個布林型列表中

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

#求準確率

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #cast轉換為0和1

with tf.session() as sess:

sess.run(init)

for epoch in range(21):

sess.run(tf.assign(lr, 0.001*(0.95**epoch))) # 迭代乙個週期給學習率重新賦值

for batch in range(n_batch):

batch_xs,batchys= mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=)

learning_rate = sess.run(lr)

acc = sess.run(accuracy,feed_dict=)

print("iter"+str(epoch)+",testing accuracy"+str(acc)+",learning rate="+str(learning_rate))

tensorflow 手寫識別

coding utf 8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one h...

tensorFlow識別手寫數字

這些天一直疲於奔波,感慨頗多。一直在思考發展方向的問題,可能自己有點技術控。很多人都覺得讀書多就應該賺錢多,讀書的最終目的就是賺錢麼。當然讀書不可能完全脫離煙火,但我個人覺得讀書的目的不是賺錢,賺錢只是附帶產品。像什麼某某沒讀什麼書,就可以賺很多錢的例子,簡直就是屁話。有誰說過讀書的作用就是賺錢。學...

TensorFlow入門MNIST手寫識別

匯入mnist資料集 訓練集有55000個樣本 測試集有10000個樣本 同時驗證集有5000個樣本 每個樣本都有它應標註資訊,即lable from tensorflow.examples.tutorials.mnist import input data mnist input data.rea...