MNIST資料集分類簡單版本(詳細)

2021-10-02 03:15:14 字數 2653 閱讀 9653

import tensorflow as tf

import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

#載入mnist資料集

mnist=input_data.read_data_sets(

"mnist_data"

,one_hot=

true

)batch_size=

100n_batch=mnist.train.num_examples//batch_size #//是整除的意思。計算一共有多少個批次

x=tf.placeholder(tf.float32,

[none

,784])

y=tf.placeholder(tf.float32,

[none,10

])#建立乙個簡單的神經網路(前向傳播)

w=tf.variable(tf.zeros(

[784,10

]))b=tf.variable(tf.zeros([10

]))prediction=tf.nn.softmax(tf.matmul(x,w)

+b)#二次代價函式(反向傳播)

loss=tf.reduce_mean(tf.square(y-prediction)

)train_step=tf.train.gradientdescentoptimizer(

0.2)

.minimize(loss)

init=tf.global_variables_initializer(

)#重點理解這兩句,有新東西。

correct_prediction=tf.equal(tf.argmax(y,1)

,tf.argmax(prediction,1)

)#tf.argmax(input,axis)根據axis取值的不同返回每行或者每列最大值的索引。axis為1表示取行最大值得索引。

#如果兩個值相等,返回true,結果儲存的是布林型的列表

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)

)#tf.cast()類似強制型別轉換,把布林型變為32位float型,然後求平均。[1,1,1,0,0,0,1,1,1,1],準確率為0.7

with tf.session(

)as sess:

sess.run(init)

for epoch in

range(21

):for batch in

range

(n_batch)

: batch_xs,batch_ys=mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=

) acc=sess.run(accuracy,feed_dict=

)print

("iter"

+str

(epoch)

+",testing accuracy"

+str

(acc)

)

執行結果

extracting mnist_data/train-images-idx3-ubyte.gz

extracting mnist_data/train-labels-idx1-ubyte.gz

extracting mnist_data/t10k-images-idx3-ubyte.gz

extracting mnist_data/t10k-labels-idx1-ubyte.gz

iter0,testing accuracy0.8312

iter1,testing accuracy0.8706

iter2,testing accuracy0.8814

iter3,testing accuracy0.888

iter4,testing accuracy0.8938

iter5,testing accuracy0.8974

iter6,testing accuracy0.9

iter7,testing accuracy0.9021

iter8,testing accuracy0.9033

iter9,testing accuracy0.9049

iter10,testing accuracy0.906

iter11,testing accuracy0.9076

iter12,testing accuracy0.9082

iter13,testing accuracy0.9091

iter14,testing accuracy0.9095

iter15,testing accuracy0.9106

iter16,testing accuracy0.9111

iter17,testing accuracy0.9124

iter18,testing accuracy0.9132

iter19,testing accuracy0.9132

iter20,testing accuracy0.9139

準確率大概在90%,接下來使用卷積神經網路將其準確率提高。

Mnist資料集分類簡單版本

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one hot true 每個批次的...

MNIST資料集分類簡單版本

from tensorflow.examples.tutorials.mnist import input data 載入資料集mnist input data.read data sets data stu05 mnist data one hot true extracting data stu...

3 3實現MNIST資料集分類

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one hot true 定義每個批...