TensorFlow實戰 手寫數字識別之K近鄰

2021-08-03 07:58:08 字數 2002 閱讀 5383

自從google發布了tensorflow之後,作為一款開源的深度學習框架,在全世界範圍內產生了巨大的影響力,如今在github上深度學習框架居於第一名,且遠遠領先其他深度學習開源專案,並且也在工業界被大量運用。學習tensorflow不僅可以加深對深度學習的理解,而且可以知道如何將深度學習這一門高深的學問用於實踐當中。

tensorflow就像其名字一樣,由「tensor」和「flow」組成,「tensor」即「張量」的意思。框架的主要思想是先構建需要的計算圖,圖中每個定點表示乙個操作,邊表示張量之間的流向或依賴關係。當整個計算圖構建完之後,啟動計算圖,系統會自動按照節點之間的依賴關係計算節點值,就能在需要的節點上獲取資料。

本文並不打算詳細介紹tensorflow的原理,想要看原理的可以直接去官網。本文主要內容是用tensorflow寫乙個入門級的演算法k近鄰實現手寫數字識別mnist。

keras提供了實現深度學習所需要的絕大部分函式庫,可實現多種神經網路模型,並可載入多種資料集來評價模型的效果。下面的**會自動載入資料,如果是第一次呼叫,資料會儲存在你的hone目錄下~/.keras/datasets/mnist.pkl.gz,大約15mb。

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

對資料的維度進行reshape,原資料是28*28大小的,要將其展開成784長度的向量,便於計算樣本間的距離。

num_pixels = x_train.shape[1] * x_train.shape[2]

x_train = x_train.reshape(x_train.shape[0], num_pixels).astype('float32')

x_test = x_test.reshape(x_test.shape[0], num_pixels).astype('float32')

## 取一部分作為訓練資料

xtr, ytr = x_train[:5000],y_train[:5000]

xte, yte = x_test,y_test

計算訓練集中的樣本距離,採用l1距離,取距離最近的乙個樣本,將其標籤賦值給測試樣本

distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

pred = tf.arg_min(distance, 0)

對測試集進行**

accuracy = 0.

# 初始化

init = tf.global_variables_initializer()

## 啟動session

with tf.session() as sess:

sess.run(init)

# 對每個測試樣本進行**

for i in range(len(xte)):

# 得到最近距離的樣本

nn_index = sess.run(pred, feed_dict=)

# 輸出**結果

print("test", i, "prediction:", np.argmax(ytr[nn_index]),

"true class:", np.argmax(yte[i]))

# 計算準確率

if np.argmax(ytr[nn_index]) == np.argmax(yte[i]):

accuracy += 1./len(xte)

print("測試完成!")

print("accuracy:", accuracy)

用最近鄰處理mnist問題可以取得較好的效果,需要調節的引數主要是近鄰的數目(k),模型的效果相當程度上依賴k的取值。雖然過程很簡單,但對於了解和熟悉tensorflow也很有幫助,同時也可以用tensorflow實現邏輯回歸、線性回歸等模型,後面會一一將其實現。

tensorflow實踐 手寫MNIST數字識別

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集,讀取的是壓縮包 mnist input data.read data sets mnist one hot true 每個...

tensorflow實現MNIST手寫數字識別

mnist資料集是由0 9,10個手寫數字組成。訓練影象有60000張,測試影象有10000張。from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets mnist data ...

tensorflow 手寫識別

coding utf 8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input data 載入資料集 mnist input data.read data sets mnist data one h...