Tensorflow05 簡單的全連線神經網路案例

2022-06-19 21:15:15 字數 2344 閱讀 4472

第乙個神經元網路就使用最簡單的全連線神經網路。

使用tensorflow裡的 fashion_mnist 服飾資料集 來完成此次的入門案例,建議使用 jupyter 分步執行,每步都理解掌握。

import

tensorflow as tf

from tensorflow import

keras

import

matplotlib.pyplot as plt

#載入資料集

fashion_mnist =keras.datasets.fashion_mnist

#得到訓練/測試 資料,訓練/測試 標籤

(train_images, train_labels), (test_images, test_labels) =fashion_mnist.load_data()

#檢視資料形狀

train_images.shape, train_labels.shape

plt.imshow(test_images[0])

#畫圖用 imshow !

#建立神經元模型

model =keras.sequential()

#第一層使用flatten

model.add(keras.layers.flatten(input_shape=(28, 28)))

model.add(keras.layers.dense(128, activation=tf.nn.relu))

model.add(keras.layers.dense(10, activation=tf.nn.softmax))

#檢視神經網路結構

model.summary()

#配置訓練方法,optimizer(優化器)為 經常使用的adam,損失函式使用sparse_categorical_crossentropy,注意還有不帶sparse的,則表示資料為 獨熱編碼形式的。

model.compile(optimizer=tf.optimizers.adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['

accuracy'])

#為防止過擬合定義該類

class mycallback(tf.keras.callbacks.callback): #

繼承自 callback

def on_epoch_end(self, epoch, logs={}): #

重寫該方法

if(logs.get('

loss

') < 0.4): #

如果 loss < 0.4, 認為發生過擬合

print("

\nloss is low so cancelling training")

self.model.stop_training = true #

停止訓練

callbacks =mycallback()

#歸一化

train_images = train_images/255test_images_scaled = test_images/255

#訓練資料得到 history 物件,最後乙個引數表示自動中止訓練,類的定義在上方

history = model.fit(train_images, train_labels, epochs=5, callbacks=[callbacks])

#利用 測試資料/測試標籤 評估模型

model.evaluate(test_images_scaled, test_labels)

#**資料,並提取第乙個(0)的**結果

model.predict(test_images_scaled)[0]

對該案例**中的一些解釋:

首先這個資料集的每個元素是二維的,即這個資料集存放著若干張,每個是乙個畫素 28*28 的二維矩陣儲存。

所以我們的模型第一層使用 flatten,作用是將二維輸入資料轉換成一維的。也就是輸入層。

dense 表示全連線網路,至於引數 啟用函式 activation 在上篇部落格中有詳細解釋。

第二個 dense 是輸出層,一共有 10 個類別,所以輸出的神經元個數為 10。這層也叫輸出層。

介於輸入輸出層之間為 隱含層,這裡的隱含層只有乙個,也是 dense,這裡神經元數量128,可以自己更改,以得到更好的訓練結果。

配置模型的編譯 compile ,優化器為 adam(),損失函式為 sparse_categorical_crossentropy

自定義的 callback 的繼承類,防止過擬合。

fit 訓練資料

evaluate 利用測試集評估模型

predict **資料

TensorFlow簡單介紹

tensorflow簡單介紹 tensorflow中文社群 中文社群中是這個介紹的 tensorflow tensorflow是乙個採用資料流圖 data flow graphs 用於數值計算的開源軟體庫。節點 nodes 在圖中表示數學操作,圖中的線 edges 則表示在節點間相互聯絡的多維資料陣...

05 庫的簡單操作

執行如下命令,檢視系統庫 show databases information schema 虛擬庫,不占用磁碟空間,儲存的是資料庫啟動後的一些引數,如使用者表資訊 列資訊 許可權資訊 字元資訊等 performance schema mysql 5.5開始新增乙個資料庫 主要用於收集資料庫伺服器效...

05 Tensorflow中變數的初始化

開啟python shell,輸入import tensorflow as tf,然後可以執行以下 1 建立乙個2 3的矩陣,並讓所有元素的值為0.型別為tf.float a tf.zeros 2,3 dtype tf.float32 2 建立乙個3 4的矩陣,並讓所有元素的值為1.b tf.one...