j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i))),j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i))),
以及j(θ)對j(θ)對引數θθ的偏導數(用於諸如梯度下降法等優化演算法的引數更新),如下:
∂∂θjj(θ)=1m∑i=1m(hθ(x(i))−y(i))x(i)j∂∂θjj(θ)=1m∑i=1m(hθ(x(i))−y(i))xj(i)
但是在大多**或數教程中,也就是直接給出了上面兩個公式,而未給出推導過程,而且這一過程並不是一兩步就可以得到的,這就給初學者造成了一定的困惑,所以我特意在此詳細介紹了它的推導過程,跟大家分享。因水平有限,如有錯誤,歡迎指正。
我們一共有m組已知樣本,(x(i),y(i))(x(i),y(i))表示第 ii 組資料及其對應的類別標記。其中x(i)=(1,x(i)1,x(i)2,...,x(i)p)tx(i)=(1,x1(i),x2(i),...,xp(i))t為p+1維向量(考慮偏置項),y(i)y(i)則為表示類別的乙個數:
這裡,只討論logistic回歸,輸入樣本資料x(i)=(1,x(i)1,x(i)2,...,x(i)p)tx(i)=(1,x1(i),x2(i),...,xp(i))t,模型的引數為θ=(θ0,θ1,θ2,...,θp)tθ=(θ0,θ1,θ2,...,θp)t,因此有
θtx(i):=θ0+θ1x(i)1+⋯+θpx(i)p.θtx(i):=θ0+θ1x1(i)+⋯+θpxp(i).
假設函式(hypothesis function)定義為:
hθ(x(i))=11+e−θtx(i)hθ(x(i))=11+e−θtx(i)
. 因為logistic回歸問題就是0/1的二分類問題,可以有
p(y^(i)=1|x(i);θ)=hθ(x(i))p(y^(i)=1|x(i);θ)=hθ(x(i))
p(y^(i)=0|x(i);θ)=1−hθ(x(i))p(y^(i)=0|x(i);θ)=1−hθ(x(i))
現在,我們不考慮「熵」的概念,根據下面的說明,從簡單直觀角度理解,就可以得到我們想要的損失函式:我們將概率取對數,其單調性不變,有
logp(y^(i)=1|x(i);θ)=loghθ(x(i))=log11+e−θtx(i),logp(y^(i)=1|x(i);θ)=loghθ(x(i))=log11+e−θtx(i),
logp(y^(i)=0|x(i);θ)=log(1−hθ(x(i)))=loge−θtx(i)1+e−θtx(i).logp(y^(i)=0|x(i);θ)=log(1−hθ(x(i)))=loge−θtx(i)1+e−θtx(i).
那麼對於第ii組樣本,假設函式表徵正確的組合對數概率為:
ilogp(y^(i)=1|x(i);θ)+ilogp(y^(i)=0|x(i);θ)=y(i)logp(y^(i)=1|x(i);θ)+(1−y(i))logp(y^(i)=0|x(i);θ)=y(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))ilogp(y^(i)=1|x(i);θ)+ilogp(y^(i)=0|x(i);θ)=y(i)logp(y^(i)=1|x(i);θ)+(1−y(i))logp(y^(i)=0|x(i);θ)=y(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))
其中,ii和ii為示性函式(indicative function),簡單理解為內條件成立時,取1,否則取0,這裡不贅言。
那麼對於一共mm組樣本,我們就可以得到模型對於整體訓練樣本的表現能力:
∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))
由以上表徵正確的概率含義可知,我們希望其值越大,模型對資料的表達能力越好。而我們在引數更新或衡量模型優劣時是需要乙個能充分反映模型表現誤差的損失函式(loss function)或者代價函式(cost function)的,而且我們希望損失函式越小越好。由這兩個矛盾,那麼我們不妨領代價函式為上述組合對數概率的相反數:
j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))
上式即為大名鼎鼎的交叉熵損失函式。(說明:如果熟悉「資訊熵「的概念e[−logpi]=−∑mi=1pilogpie[−logpi]=−∑i=1mpilogpi,那麼可以有助理解叉熵損失函式)
這步需要用到一些簡單的對數運算公式,這裡先以編號形式給出,下面推導過程中使用特意說明時都會在該步驟下腳標標出相應的公式編號,以保證推導的連貫性。
① logab=loga−logb logab=loga−logb
② loga+logb=log(ab) loga+logb=log(ab)
③ a=logea a=logea
另外,值得一提的是在這裡涉及的求導均為矩陣、向量的導數(矩陣微商),這裡有一篇教程總結得精簡又全面,非常棒,推薦給需要的同學。
下面開始推導:
交叉熵損失函式為:
j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))j(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))
其中,
loghθ(x(i))=log11+e−θtx(i)=−log(1+e−θtx(i)) ,log(1−hθ(x(i)))=log(1−11+e−θtx(i))=log(e−θtx(i)1+e−θtx(i))=log(e−θtx(i))−log(1+e−θtx(i))=−θtx(i)−log(1+e−θtx(i))①③ .loghθ(x(i))=log11+e−θtx(i)=−log(1+e−θtx(i)) ,log(1−hθ(x(i)))=log(1−11+e−θtx(i))=log(e−θtx(i)1+e−θtx(i))=log(e−θtx(i))−log(1+e−θtx(i))=−θtx(i)−log(1+e−θtx(i))①③ .
由此,得到
j(θ)=−1m∑i=1m[−y(i)(log(1+e−θtx(i)))+(1−y(i))(−θtx(i)−log(1+e−θtx(i)))]=−1m∑i=1m[y(i)θtx(i)−θtx(i)−log(1+e−θtx(i))]=−1m∑i=1m[y(i)θtx(i)−logeθtx(i)−log(1+e−θtx(i))]③=−1m∑i=1m[y(i)θtx(i)−(logeθtx(i)+log(1+e−θtx(i)))]②=−1m∑i=1m[y(i)θtx(i)−log(1+eθtx(i))]j(θ)=−1m∑i=1m[−y(i)(log(1+e−θtx(i)))+(1−y(i))(−θtx(i)−log(1+e−θtx(i)))]=−1m∑i=1m[y(i)θtx(i)−θtx(i)−log(1+e−θtx(i))]=−1m∑i=1m[y(i)θtx(i)−logeθtx(i)−log(1+e−θtx(i))]③=−1m∑i=1m[y(i)θtx(i)−(logeθtx(i)+log(1+e−θtx(i)))]②=−1m∑i=1m[y(i)θtx(i)−log(1+eθtx(i))]
這次再計算j(θ)j(θ)對第jj個引數分量θjθj求偏導:
∂∂θjj(θ)=∂∂θj(1m∑i=1m[log(1+eθtx(i))−y(i)θtx(i)])=1m∑i=1m[∂∂θjlog(1+eθtx(i))−∂∂θj(y(i)θtx(i))]=1m∑i=1m⎛⎝x(i)jeθtx(i)1+eθtx(i)−y(i)x(i)j⎞⎠=1m∑i=1m(hθ(x(i))−y(i))x(i)j∂∂θjj(θ)=∂∂θj(1m∑i=1m[log(1+eθtx(i))−y(i)θtx(i)])=1m∑i=1m[∂∂θjlog(1+eθtx(i))−∂∂θj(y(i)θtx(i))]=1m∑i=1m(xj(i)eθtx(i)1+eθtx(i)−y(i)xj(i))=1m∑i=1m(hθ(x(i))−y(i))xj(i)
這就是交叉熵對引數的導數:
∂∂θjj(θ)=1m∑i=1m(hθ(x(i))−y(i))x(i)j
softmax交叉熵損失函式求導
來寫乙個softmax求導的推導過程,不僅可以給自己理清思路,還可以造福大眾,豈不美哉 softmax經常被新增在分類任務的神經網路中的輸出層,神經網路的反向傳播中關鍵的步驟就是求導,從這個過程也可以更深刻地理解反向傳播的過程,還可以對梯度傳播的問題有更多的思考。softmax 柔性最大值 函式,一...
交叉熵代價函式
交叉熵代價函式 cross entropy cost function 是用來衡量人工神經網路 ann 的 值與實際值的一種方式。與二次代價函式相比,它能更有效地促進ann的訓練。在介紹交叉熵代價函式之前,本文先簡要介紹二次代價函式,以及其存在的不足。ann的設計目的之一是為了使機器可以像人一樣學習...
交叉熵代價函式
交叉熵代價函式 cross entropy cost function 是用來衡量人工神經網路 ann 的 值與實際值的一種方式。與二次代價函式相比,它能更有效地促進ann的訓練。在介紹交叉熵代價函式之前,本文先簡要介紹二次代價函式,以及其存在的不足。ann的設計目的之一是為了使機器可以像人一樣學習...