Pytorch之簡潔版Softmax多分類

2021-10-24 03:52:09 字數 3819 閱讀 9910

中,我們自己手動實現了對於softmax操作和交叉熵的計算,可其實這些在pytorch框架中已經被實現了,我們直接拿來使用即可。但是,為了能夠對這些內容有著更深刻的理解,通常我們都會自己手動實現一次,然後在今後的使用中就可以直接拿現成的來用了。在接下來這篇文章中,筆者將首先介紹如何呼叫pytorch中的交叉熵損失函式,然後再同時借助nn.linear()來實現乙個簡潔版的softmax回歸。

在前一篇文章中,我們首先分別自己實現了softmax和交叉熵的操作;然後再將兩者結合實現了交叉熵損失函式的計算過程。但其實這兩步通過pytorch中的crossentropyloss()就能實現。

def

softmax

(x):

s = torch.exp(x)

return s / torch.

sum(s, dim=

1, keepdim=

true

)def

crossentropy

(y_true, logits)

: c =

-torch.log(logits.gather(

1, y_true.reshape(-1

,1))

)return torch.

sum(c)

logits = torch.tensor([[

0.5,

0.3,

0.6],[

0.5,

0.4,

0.3]])

y = torch.longtensor([2

,1])

c = crossentropy(y, softmax(logits))/

len(y)

print

(c)loss = torch.nn.crossentropyloss(reduction=

'mean'

)# 返回的均值是除以的每一批樣本的個數(不一定是batchsize,因為最後乙個batch的樣本可能很少)

cc = loss(logits, y)

print

(cc)

#結果:

tensor(

1.0374

)tensor(

1.0374

)

從上述**可以看出,僅僅用pytorch中的兩行**就能實現我們需要的功能。同時,需要注意的是當crossentropyloss中指定引數reduction='mean'時,返回的均值是總損失除以輸入的樣本數量,而不是batchsize或者batchsize*cc表示類別數)。

def

train()

: input_nodes =28*

28 output_nodes =

10 epochs =

5 lr =

0.1 batch_size =

256 train_iter, test_iter = loaddataset(batch_size)

net = nn.sequential(nn.flatten(),

nn.linear(input_nodes, output_nodes)

) loss = nn.crossentropyloss(reduction=

'mean'

) optimizer = torch.optim.sgd(net.parameters(

), lr=lr)

# 定義優化

同之前實現簡潔版的線性回歸一樣,通過pytorch封裝好的api,僅僅兩行**就能夠實現softmax分類模型。其中nn.flatten()用於將輸入「拉平」成乙個向量,在此處就是將輸入的展成乙個784維的向量。同時,在倒數第2行**中,我們還直接呼叫了pytorch中的交叉熵損失函式。最後1行**我們定義了乙個sgd優化器。接下來,就是通過乙個迴圈來對網路進行訓練:

for epoch in

range

(epochs)

:for i,

(x, y)

inenumerate

(train_iter)

: logits = net(x)

l = loss(logits, y)

optimizer.zero_grad(

) l.backward(

) optimizer.step(

)# 執行梯度下降

上述**的含義在之前的文章中已經介紹過,在此就不再贅述。同時,完整示例**可在引用[2]中進行獲取。

if __name__ ==

'__main__'

: mnist_train, mnist_test = loaddataset(

) train(mnist_train, mnist_test)

#結果epochs[5/

2]--

-batch[

234/

150]--

-acc 0.8438--

-loss 0.5209

epochs[5/

2]--

-batch[

234/

200]--

-acc 0.8086--

-loss 0.559

epochs[5/

2]--acc on test 0.8158

epochs[5/

3]--

-batch[

234/0]

---acc 0.8438--

-loss 0.4491

epochs[5/

3]--

-batch[

234/50]

---acc 0.8047--

-loss 0.5441

epochs[5/

3]--

-batch[

234/

100]--

-acc 0.8203--

-loss 0.5268

epochs[5/

3]--

-batch[

234/

150]--

-acc 0.8125--

-loss 0.4648

epochs[5/

3]--

-batch[

234/

200]--

-acc 0.875--

-loss 0.4342

epochs[5/

3]--acc on test 0.8124

在這篇文章中,筆者首先介紹了pytorch封裝好的softmax交叉熵損失函式crossentropyloss(),同時將其計算結果與我們自己實現的方法的結果進行了對比;接著筆者通過pytorch封裝好的介面快速的實現了乙個softamx分類器,同時還介紹了nn.flatten()的作用。

[1]動手深度學習

[2]示例**:

[1]pytorch之softmax多分類任務

[2]想明白多分類必須得談邏輯回歸

[3]pytorch之linear與mseloss

[4]pytorch之擬合正弦函式你會嗎?

[5]你告訴我什麼是深度學習

python爬蟲之新浪網(簡潔版)

爬蟲 python 注釋挺詳細了,直接上全部 歡迎各位大佬批評指正。from selenium import webdriver from selenium.webdriver.chrome.options import options from selenium.webdriver.common....

git操作命令(簡潔版)

git clone gitgit checkout b 分支名 git checkout d 分支名git checkout 分支名git pullgit add 把要提交的所有修改放到暫存區 一次新的提交git commit m 內容 把暫存區的所有修改提交到分支 git push origin ...

Redis個人總結簡潔版

list hash setzset set expire原子性 如果setnx和expire中間出現意外打斷,造成expire沒有得到執行,那麼這個鎖將永遠得不到釋放 超時問題 可重入問題 上述的策略都是不支援可重入鎖的 redlock 普通演算法存在的問題 在主從結構中,如果某個執行緒a剛在主節點...