Mnist資料集分類簡單版本

2021-09-27 09:59:41 字數 1712 閱讀 6326

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

#載入資料集

mnist = input_data.read_data_sets("mnist_data",one_hot=true)

#每個批次的大小

batch_size = 100

#計算一共有多少個批次

n_batch = mnist.train.num_examples //batch_size

#定義兩個placeholder

x = tf.placeholder(tf.float32,[none,784])

y = tf.placeholder(tf.float32,[none,10])

#建立乙個簡單的神經網路

w = tf.variable(tf.zeros([784,10]))

b = tf.variable(tf.zeros([10]))

prediction = tf.nn.softmax(tf.matmul(x,w)+b)

#二次代價函式

loss = tf.reduce_mean(tf.square(y-prediction))

#使用梯度下降法

train_step = tf.train.gradientdescentoptimizer(0.2).minimize(loss)

#初始化變數

init = tf.global_variables_initializer()

#結果存放在乙個布林型列表中

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置

#求準確率

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.session() as sess:

sess.run(init)

for epoch in range(21):

for batch in range(n_batch):

batch_xs,batch_ys = mnist.train.next_batch(batch_size)

sess.run(train_step,feed_dict=)

acc = sess.run(accuracy,feed_dict=)

print("iter" + str(epoch) + ",testing accuracy " + str(acc))

one_hot=true有乙個0,其餘為1 

x = tf.placeholder(tf.float32,[none,784]),784表示把資料集鋪平展開為一維陣列則有一行784個資料

y = tf.placeholder(tf.float32,[none,10]), 10表示標籤有10種

經過21次訓練後,精確度由0.8334變成了0.9137,還可以繼續優化,以下有幾種優化思路:

1.可以改變每個批次的大小

2.可以改變權重w和偏置值b

3.此處訓練了21次,可以改變訓練次數

此次執行中還出現了乙個問題,就是說input_data沒有定義,原因是沒有執行最上面兩行,最上面的兩行也是需要單獨執行一遍的。

MNIST資料集分類簡單版本

from tensorflow.examples.tutorials.mnist import input data 載入資料集mnist input data.read data sets data stu05 mnist data one hot true extracting data stu...

MNIST資料集分類簡單版本(詳細)

import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input data 載入mnist資料集 mnist input data.read data sets mnist...

3 3實現MNIST資料集分類

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one hot true 定義每個批...