tensorflow實現神經網路的優化

2021-10-08 17:32:11 字數 2128 閱讀 6835

tensorflow筆記系列文章均參考自中國大學mooc上北京大學軟體與微電子學院曹建老師的《tensorflow筆記2》課程。曹建老師講的非常棒,受益良多,強烈建議tensorflow初學者學習。

神經網路複雜度多用神經網路層數和神經網路引數的個數來表示。

空間複雜度:

時間複雜度:

對於學習率來說,我們知道的是,學習率設定過小,模型收斂會很慢,學習率設定過大,模型可能不收斂。那麼怎樣設定乙個合適的學習率呢,可以使用指數衰減學習率來解決這個問題。公式如下:

指數衰減學習率 = 初始學習率*學習率衰減率**(當前輪數/多少輪衰減一次)

優秀的啟用函式:

啟用函式輸出值的範圍:

**損失函式(loss)**就是**值與已知答案的差距,神經網路的優化目標就是找到一套引數,使得損失函式的值最小。

損失函式的定義常使用均方誤差(mse),在tensorflow中可以用以下方式來定義:loss_mse = tf.reduce_mean(tf.square(y_ - y)),其中y_是真實值,y是**值。

除此之外,損失函式的定義還能使用均方誤差和交叉熵,自定義損失函式就不介紹了,下面介紹一下交叉熵函式。

**交叉熵函式ce(cross entropy)**表示兩個概率分布之間的距離:

h (y

−,y)

=−∑y

−∗ln

yh(y_ -, y) = -\sum

h(y−​,

y)=−

∑y−​

∗lny

交叉熵函式的值越小表示**的越準確,用tensorflow可以以下方式實現:tf.losses.categorical_crossentropy(y_,y)

在執行分類問題時,我們通常會使用softmax函式讓輸出結果符合概率分布,再計算損失函式的值,tensorflow給出了乙個softmax函式和交叉熵函式結合的方法:tf.nn.softmax_cross_entropy_with_logits(y_, y)

欠擬合是模型沒有很好的表現資料集的特徵,過擬合是模型對資料集的特徵表現的太好了,以至於泛化能力弱,不能很好的識別未見過的資料。

欠擬合的解決方法:

過擬合的解決方法:

通常情況下,欠擬合是更好解決的,提公升模型複雜度就是乙個很好用的方法。對於過擬合,我們有以下幾種方式:

正則化緩解過擬合:正則化在損失函式中 引入模型複雜度指標,利用給w加權值,弱化了訓練資料的雜訊(一般不正則化b),公式如下:

l os

s=lo

ss(y

,y−)

+reg

ular

zer∗

loss

(w

)loss = loss(y,y_-) + regularzer*loss(w)

loss=l

oss(

y,y−

​)+r

egul

arze

r∗lo

ss(w

)其中loss(y與y_)代表模型中所有引數的損失函式,比如交叉熵和均方誤差

regularzer表示給出引數w在總loss中的比例,也就是正則化的權重,這是乙個超引數

loss(w)表示選擇的正則化方式,w表示需要正則化的引數,正則化的選擇有l1正則化l2正則化,公式如下

l os

sl1(

w)=∑

i∣wi

∣loss_(w) = \sum_i

lossl1

​(w)

=i∑​

∣wi​

∣ l os

sl2(

w)=∑

i∣wi

2∣

loss_(w) = \sum_i

lossl2

​(w)

=i∑​

∣wi2

​∣正則化的選擇:

tensorflow實戰 實現簡單的神經網路

from tensorflow.examples.tutorials.mnist import input data import tensorflow as tf mnist input data.read data sets mnist data one hot true sess tf.int...

TensorFlow實現高階的卷積神經網路

本人使用的資料集是cifar 10。這是乙個經典的資料集,許多 也都是在這個資料集上進行訓練。使用的卷積神經網路是根據alex描述的cuda convnet模型修改得來。在這個神經網路中,我使用了一些新的技巧 1 對weights進行了l2的正則化 2 將影象進行翻轉 隨機剪下等資料增強,製造了更多...

基於tensorflow2實現卷積神經網路

利用tensorflow2中的api實現乙個簡單的卷積神經網路,完成梯度下降的操作並繪製訓練集和測試集準確率曲線。資料集在這裡 資料分布 訓練集數量為209,測試集數量為50 import numpy as np import matplotlib.pyplot as plt import tens...