交叉熵損失函式理解

2021-08-28 20:19:57 字數 1676 閱讀 9284

交叉熵損失函式的數學原理

我們知道,在二分類問題模型:例如邏輯回歸「logistic regression」、神經網路「neural network」等,真實樣本的標籤為 [0,1],分別表示負類和正類。模型的最後通常會經過乙個 sigmoid 函式,輸出乙個概率值,這個概率值反映了**為正類的可能性:概率越大,可能性越大。

sigmoid 函式的表示式和圖形如下所示:

g(s)=11+e−sg(s)=11+e−sg(s)=\frac}

其中 s 是模型上一層的輸出,sigmoid 函式有這樣的特點:s = 0 時,g(s) = 0.5;s >> 0 時, g ≈ 1,s << 0 時,g ≈ 0。顯然,g(s) 將前一級的線性輸出對映到 [0,1] 之間的數值概率上。這裡的 g(s) 就是交叉熵公式中的模型**輸出 。

我們說了,**輸出即 sigmoid 函式的輸出表徵了當前樣本標籤為 1 的概率:

y=p(y=1|x)y=p(y=1|x)\hat y=p(y=1|x)

很明顯,當前樣本標籤為 0 的概率就可以表達成:

1−y=p(y=0|x)1−y=p(y=0|x)1-\hat y=p(y=0|x)

重點來了,如果我們從極大似然性的角度出發,把上面兩種情況整合到一起:

p(y|x)=yy⋅(1−y)1−yp(y|x)=yy⋅(1−y)1−yp(y|x)=\hat y^y\cdot (1-\hat y)^

不懂極大似然估計也沒關係。我們可以這麼來看:

當真實樣本標籤 y = 0 時,上面式子第一項就為 1,概率等式轉化為:

p(y=0|x)=1−yp(y=0|x)=1−yp(y=0|x)=1-\hat y

當真實樣本標籤 y = 1 時,上面式子第二項就為 1,概率等式轉化為:

p(y=1|x)=yp(y=1|x)=yp(y=1|x)=\hat y

兩種情況下概率表示式跟之前的完全一致,只不過我們把兩種情況整合在一起了。

重點看一下整合之後的概率表示式,我們希望的是概率 p(y|x) 越大越好。首先,我們對 p(y|x) 引入 log 函式,因為 log 運算並不會影響函式本身的單調性。則有:

log p(y|x)=log(yy⋅(1−y)1−y)=ylog y+(1−y)log(1−y)log p(y|x)=log(yy⋅(1−y)1−y)=ylog y+(1−y)log(1−y)log\ p(y|x)=log(\hat y^y\cdot (1-\hat y)^)=ylog\ \hat y+(1-y)log(1-\hat y)

我們希望 log p(y|x) 越大越好,反過來,只要 log p(y|x) 的負值 -log p(y|x) 越小就行了。那我們就可以引入損失函式,且令 loss = -log p(y|x)即可。則得到損失函式為:

l=−[ylog y^+(1−y)log (1−y^)]l=−[ylog y^+(1−y)log (1−y^)]l=-[ylog\ \hat y+(1-y)log\ (1-\hat y)]

非常簡單,我們已經推導出了單個樣本的損失函式,是如果是計算 n 個樣本的總的損失函式,只要將 n 個 loss 疊加起來就可以了:

l=∑i=1ny(i)log y^(i)+(1−y(i))log (1−y^(i))l=∑i=1ny(i)log y^(i)+(1−y(i))log (1−y(i))l=\sum_ny^log\ \hat y+(1-y)log\ (1-\hat y^)

這樣,我們已經完整地實現了交叉熵損失函式的推導過程。

交叉熵損失函式 交叉熵損失函式和均方差損失函式引出

交叉熵 均方差損失函式,加正則項的損失函式,線性回歸 嶺回歸 lasso回歸等回歸問題,邏輯回歸,感知機等分類問題 經驗風險 結構風險,極大似然估計 拉普拉斯平滑估計 最大後驗概率估計 貝葉斯估計,貝葉斯公式,頻率學派 貝葉斯學派,概率 統計 記錄被這些各種概念困擾的我,今天終於理出了一些頭緒。概率...

交叉熵損失函式

公式 分類問題中,我們通常使用 交叉熵來做損失函式,在網路的後面 接上一層softmax 將數值 score 轉換成概率。如果是二分類問題,我們通常使用sigmod函式 2.為什麼使用交叉熵損失函式?如果分類問題使用 mse 均方誤差 的方式,在輸出概率接近0 或者 接近1的時候,偏導數非常的小,學...

交叉熵損失函式

監督學習的兩大種類是分類問題和回歸問題。交叉熵損失函式主要應用於分類問題。先上實現 這個函式的功能就是計算labels和logits之間的交叉熵。tf.nn.softmax cross entropy with logits logits y,labels y 首先乙個問題,什麼是交叉熵?交叉熵 c...