tensorflow mnist新手文件

2021-09-29 23:08:36 字數 2430 閱讀 2344

官方文件

minist資料庫

每一張對應28x28大小的灰度圖,也就是大小為784,55,000張訓練資料[55000, 784],10,000測試資料,每個資料對應乙個label標籤(0到9)

但是label用的是one-hot vectors(獨熱編碼)格式[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],那麼訓練資料的標籤庫就是[55000, 10]資料為浮點。

tensorflow.placeholder(dtype, shape=none, name=none), 

dtype:資料型別。常用的是tf.float32,tf.float64等數值型別

shape:資料形狀。預設是none,就是一維值,也可以是多維,比如[2,3], [none, 3]表示列是3,行不定

name:名稱

import tensorflow as tf

讀mnist資料

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist_data/", one_hot=true)

x = tf.placeholder(tf.float32, [none, 784])

w = tf.variable(tf.zeros([784, 10]))

b = tf.variable(tf.zeros([10]))

wx + b = y

因為y是55000x10,所以w要是784x10, b要為1x10

公式用tf表示就是y = tf.nn.softmax(tf.matmul(x, w) + b)

這裡建議用tf.nn.softmax_cross_entropy_with_logits (tf.matmul(x, w) + b) ,更穩定

常規的損失函式

y為實際輸出,label為理論的

差的平方和 sum((y - label)^2)

交叉熵 -sum(label * log(y))   -- 最常用的計算標籤(label)與輸出(y)之間差別的方法

[0, 0, 1] 與 [0.1, 0.3, 0.6]的交叉熵為 -log(0.6) = 0.51

[0, 0, 1] 與 [0.2, 0.2, 0.6]的交叉熵為 -log(0.6) = 0.51

[0, 0, 1] 與 [0.1, 0, 0.9]的交叉熵為 -log(0.9) = 0.10

當label為0時,交叉熵為0,label為1時,交叉熵為-log(y),交叉熵只關注獨熱編碼中有效位的損失。這樣遮蔽了無效位值的變化(無效位的值的變化並不會影響最終結果),並且通過取對數放大了有效位的損失。當有效位的值趨近於0時,交叉熵趨近於正無窮大。

label資料 y_ = tf.placeholder(tf.float32, [none, 10])

reduce_mean即求正確率,值越低說明交叉熵越低說明越相似

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

訓練,用梯度下降法,使得cross_entropy盡量最小,學習率為0.5,這裡tf還有很多其他的方法

train_step = tf.train.gradientdescentoptimizer(0.5).minimize(cross_entropy)

載入模型

sess = tf.interactivesession()

初始化我們建立的變數

tf.global_variables_initializer().run()

把batch_xs讀入x成為訓練資料 , batch_ys讀入y_成為訓練的標籤

for _ in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict=)

計算相似性

# argmax 返回最大值的下標,最大值的下標即答案

# 例如 [0,0,0,0.9,0,0.1,0,0,0,0] 代表數字3

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

這裡返回一串的bool值,比如[true, true, true, false, false, true]

# correct_prediction  -> [true, true, true, false, false, true]  這個總6個,4個true就是2/3的正確率

# reduce_mean即求predict的平均數 即 正確個數 / 總數,即正確率

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Tensorflow MNIST 手寫識別

這是乙個系列,記錄我tensorflow開發常用的 小常識,有些是參考網上 講的可能有點爛,求不要打臉,嚶嚶嚶 送給那些需要的人。可以相互交流,喜歡的加我吧。wx lxp911221 根據官方文件敲 mnist機器學習入門 第乙個例子,碰到乙個煩了我很久的問題 input data 一直找不到。wh...

Tensorflow MNIST簡單全連線層分類

time 2020 1 5 22 39 author x1aolata file mnist train.py script 訓練簡單手寫數字識別模型 直接全連線 用於測試模型儲存與轉化 from future import print function import tensorflow as t...

linux新手入門 3 檔案,壓縮

首先我們隨便建立個使用者 用 useradd tom 加乙個tom使用者 然後給他加密碼 passwd tom 超級管理員 普通使用者 cd 回家目錄 如果 是管理員就是回到root目錄 如果普通使用者就是回到 home裡的 使用者 cd 回之前目錄 su root 到超級管理員 退出就是 exit...