torch.autograd是pytorch最重要的元件,主要包括variable類和function類,variable用來封裝tensor,是計算圖上的節點,function則定義了運算操作,是計算圖上的邊。
1.tensor
tensor張量和numpy陣列的區別是它不僅可以在cpu上執行,還可以在gpu上執行。
tensor其實包含乙個資訊頭和乙個資料儲存型別torch.storage,torch.storage是乙個單一資料型別的連續一維陣列。可以用tensor.is_contiguous()檢驗張量的資料儲存是否是在連續記憶體上,tensor必須連續才能夠使用view操作改變tensor形狀,如果不連續則可以使用tensor.contiguous()使之連續。
2.variable
注:自pytorch0.4.0版本之後,variable型別和tensor型別合併,在**中不用再把tensor轉換為variable。
var =
variable
(tensor, requires_grad=true)
在計算圖中,只要有乙個節點使用了requires_grad=true,它的後續關聯節點都會成為requires_grad=true,就是說都需要計算梯度。可以
var =
variable
(tensor,
volatile
=true)
只要有乙個節點使用volatile=true整個計算圖就不會呼叫.backward(),用於推理過程。
注意在variable不支援inplace運算操作,因為這樣導致變數值被更改,反向傳播的時候無法再使用,因此為了避免計算錯誤,計算圖**現inplace操作時, pytorch會報錯。
可用tensor.is_leaf()判斷某個變數在計算圖中是否是葉子節點,只有葉子節點會保留grad,其他張量不保留,如果非葉子節點需要保留梯度,則使用tensor.retain_grad()即可。
2.1 .backward()對計算圖進行反向傳播更新梯度
對於計算圖中的乙個標量,比如損失函式的輸出loss,可以直接進行.backward()操作。如果是乙個張量,比如中間過程,則必須指定和該張量同形狀的grad_tensor,具體涉及到反向傳播過程(復合函式鏈式法則求偏導)的jacobian矩陣。
2.2 torch.nn.parameter()
parameter是variable的子類,但parameter類會出現在模型的引數列表中(即會出現在model.parameters()迭代器中),且parameter類預設requires_grad=true,且不能設定volatile。
2.3 凍結網路部分引數
可以用detach()把張量從計算圖中分離出來,分離出來的變數不求梯度,可以用來凍結部分網路權重引數(**示例待補充)。也可以通過設定網路前面部分引數的requires_grad=false來凍結網路。
model = torchvision.models.
resnet18
(pretrained=true)
for param in model.
parameters()
: param.requires_grad = false
model.fc = nn.
linear
(512
,100
)optimizer = optim.
sgd(model.fc.
parameters()
, lr=
1e-2
, momentum=
0.9)
3.function
function是對variable進行的運算,定義了forward()方法和backward()方法。可以與nn.module()對比來理解。兩者都可以實現運算,但是function無法儲存引數,用於不需要更新引數的操作,例如各種啟用函式、池化等運算,而module可以儲存引數,則用於線性層、卷積層等運算。使用function自定義運算時必須重寫forward()和backward()方法,而使用module自定義運算時,只需要寫forward()即可,backward()可由module中的各種元件自動求解了。
pytorch基礎知識整理(四) 模型
torch.nn.module 是所有網路模型的基類,所有網路都需要繼承此類,模板如下 import torch.nn as nn import torch.nn.functional as f class model nn.module def init self super init 表示繼承父...
基礎知識整理
1.在輸出字元變數的值時,可以選擇以十進位制整數形式輸出,或以字元形式輸出。2.在乙個整數的末尾加大寫字母l或小寫字母l,表示它是長整型。3.代表除法運算子,兩個實數相除的結果是雙精度實數。兩個整數相除的結果是整數,捨去小數部分。但是,如果除數或被除數中有乙個是負值,則捨入的方向是不固定的。多數c編...
c 基礎知識整理(一)
一 標頭檔案 1 define保護 為防止標頭檔案被多重包含,檔案的格式應該為 h 這樣寫是為了保證其唯一性 2 內聯函式 在編譯的時候,編譯器會將它自動展開 所以合理的使用內聯函式會提高效率 內聯函式一般都是短小的,但要除for,while這類的。有些函式即使不加了inline 也不一定會變成內聯...