softmax的多分類

2021-08-27 14:18:06 字數 2141 閱讀 8325

我們常見的邏輯回歸、svm等常用於解決二分類問題,對於多分類問題,比如識別手寫數字,它就需要10個分類,同樣也可以用邏輯回歸或svm,只是需要多個二分類來組成多分類,但這裡討論另外一種方式來解決多分類——softmax。

softmax的函式為

p(i)=\dfrac^exp(\theta_k^tx)}
可以看到它有多個值,所有值加起來剛好等於1,每個輸出都對映到了0到1區間,可以看成是概率問題。

θtixθitx為多個輸入,訓練其實就是為了逼近最佳的θtθt。

從下圖看,神經網路中包含了輸入層,然後通過兩個特徵層處理,最後通過softmax分析器就能得到不同條件下的概率,這裡需要分成三個類別,最終會得到y=0、y=1、y=2的概率值。

繼續看下面的圖,三個輸入通過softmax後得到乙個陣列[0.05 , 0.10 , 0.85],這就是soft的功能。

計算過程直接看下圖,其中zlizil即為θtixθitx,三個輸入的值分別為3、1、-3,ezez的值為20、2.7、0.05,再分別除以累加和得到最終的概率值,0.88、0.12、0。

對於訓練集,有y(i)∈y(i)∈,總共有k個分類。對於每個輸入x都會有對應每個類的概率,即p(y=j|x)p(y=j|x),從向量角度來看,有,

hθ(x(i))=⎡⎣⎢⎢⎢⎢⎢p(y(i)=1|x(i);θ)p(y(i)=2|x(i);θ)⋮p(y(i)=k|x(i);θ)⎤⎦⎥⎥⎥⎥⎥=1∑kj=1eθtj⋅x(i)⎡⎣⎢⎢⎢⎢⎢eθt1⋅x(i)eθt2⋅x(i)⋮eθtk⋅x(i)⎤⎦⎥⎥⎥⎥⎥hθ(x(i))=[p(y(i)=1|x(i);θ)p(y(i)=2|x(i);θ)⋮p(y(i)=k|x(i);θ)]=1∑j=1keθjt⋅x(i)[eθ1t⋅x(i)eθ2t⋅x(i)⋮eθkt⋅x(i)]

softmax的代價函式定為如下,其中包含了示性函式11,表示如果第i個樣本的類別為j則yij=1yij=1。代價函式可看成是最大化似然函式,也即是最小化負對數似然函式。

j(θ)=−1m[∑mi=1∑kj=11⋅log(p(y(i)=j|x(i);θ))]j(θ)=−1m[∑i=1m∑j=1k1⋅log(p(y(i)=j|x(i);θ))]

其中,p(y(i)=j|x(i);θ)=exp(θtix)∑kk=1exp(θtkx)p(y(i)=j|x(i);θ)=exp(θitx)∑k=1kexp(θktx)則,

j(θ)=−1m[∑mi=1∑kj=11⋅(θtjx(i)−log(∑kl=1eθtl⋅x(i)))]j(θ)=−1m[∑i=1m∑j=1k1⋅(θjtx(i)−log(∑l=1keθlt⋅x(i)))]

一般使用梯度下降優化演算法來最小化代價函式,而其中會涉及到偏導數,即θj:=θj−αδθjj(θ)θj:=θj−αδθjj(θ),則j(θ)j(θ)對θjθj求偏導,得到,

∇j(θ)∇θj=−1m∑mi=1[∇∑kj=11θtjx(i)∇θj−∇∑kj=11log(∑kl=1eθtl⋅x(i)))∇θj]∇j(θ)∇θj=−1m∑i=1m[∇∑j=1k1θjtx(i)∇θj−∇∑j=1k1log(∑l=1keθlt⋅x(i)))∇θj]

=−1m∑mi=1[1x(i)−∇∑kj=11∑kl=1eθtl⋅x(i)∑kl=1eθtl⋅x(i)∇θj]=−1m∑i=1m[1x(i)−∇∑j=1k1∑l=1keθlt⋅x(i)∑l=1keθlt⋅x(i)∇θj]

=−1m∑mi=1[1x(i)−x(i)eθtj⋅x(i)∑kl=1eθtl⋅x(i)]=−1m∑i=1m[1x(i)−x(i)eθjt⋅x(i)∑l=1keθlt⋅x(i)]

=−1m∑mi=1x(i)[1−p(y(i)=j|x(i);θ)]=−1m∑i=1mx(i)[1−p(y(i)=j|x(i);θ)]

得到代價函式對引數權重的梯度就可以優化了。

在多分類場景中可以用softmax也可以用多個二分類器組合成多分類,比如多個邏輯分類器或svm分類器等等。該使用softmax還是組合分類器,主要看分類的類別是否互斥,如果互斥則用softmax,如果不是互斥的則使用組合分類器。

softmax的多分類

from 我們常見的邏輯回歸 svm等常用於解決二分類問題,對於多分類問題,比如識別手寫數字,它就需要10個分類,同樣也可以用邏輯回歸或svm,只是需要多個二分類來組成多分類,但這裡討論另外一種方式來解決多分類 softmax。softmax的函式為 p i exp tix kk 1exp tkx ...

softmax多分類學習

softmax回歸從零開始實現 import torch import torchvision import numpy as np import sys import d2lzh pytorch as d2l 獲取資料 batch size 256 train iter,test iter d2l...

softMax交叉熵多分類引數調優之批次大小選擇

import tensorflow as tf import os import numpy as np import numpy as np os.environ tf cpp min log level 3 輸入隨機種子 myseed eval input learning rate eval ...