最近遇到這個函式,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答並加上自己的理解。
在pytorch中,只有很少幾個操作是不改變tensor的內容本身,而只是重新定義下標與元素的對應關係的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是元資料。
這些操作是:
舉個栗子,在使用transpose()進行轉置操作時,pytorch並不會建立新的、轉置後的tensor,而是修改了tensor中的一些屬性(也就是元資料),使得此時的offset和stride是與轉置tensor相對應的。
轉置的tensor和原tensor的記憶體是共享的!
為了證明這一點,我們來看下面的**:
x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233
可以看到,改變了y的元素的值的同時,x的元素的值也發生了變化。
也就是說,經過上述操作後得到的tensor,它內部資料的布局方式和從頭開始建立乙個這樣的常規的tensor的布局方式是不一樣的!於是…這就有contiguous()的用武之地了。
在上面的例子中,x是contiguous的,但y不是(因為內部資料不是通常的布局方式)。
注意不要被contiguous的字面意思「連續的」誤解,tensor中資料還是在記憶體中一塊區域裡,只是布局的問題!
當呼叫contiguous()時,會強制拷貝乙份tensor,讓它的布局和從頭建立的一毛一樣。
一般來說這一點不用太擔心,如果你沒在需要呼叫contiguous()的地方呼叫contiguous(),執行時會提示你:
runtimeerror: input is not contiguous
只要看到這個錯誤提示,加上contiguous()就好啦~
補充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax
torch.gather(input,dim,index,out=none)。對指定維進行索引。比如4*3的張量,對dim=1進行索引,那麼index的取值範圍就是0~2.
input是乙個張量,index是索引張量。input和index的size要麼全部維度都相同,要麼指定的dim那一維度值不同。輸出為和index大小相同的張量。
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
b=torch.longtensor([[1,2,1],
[2,2,2],
[2,2,2],
[1,1,0]])
b=b.view(4,3)
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.longtensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,程式設計客棧c))
輸出:tensor([[ 0.2000, 0.3000, 0.2000],
[ 1.3000, 1.3000, 1.3000],
[ 2.3000, 2.3000, 2.3000],
[ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
[ 2.1000, 2.2000, 2.3000],
[ 2.1000, 2.2000, 2.3000],
[ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
[ 1.3000],
[ 2.1000],
[ 3.2000]])
將維度為1的壓縮掉。如size為(3,1,1,2),壓縮之後為(3,2)
import torch
a=torch.randn(2,1,1,3)
print(awww.cppcns.com)
print(a.squeeze())
輸出:tensor([[[[-0.2320, 0.9513, 1.1613]]],
[[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
[ 0.0901, 0.9613, -0.9344]])
擴充套件某個size為1的維度。如(2,2,1)擴充套件為(2,2,3)
import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)
輸出:tensor([[[ 0.0608],
[ 2.2106]],
[[程式設計客棧-1.9287],
[ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
[ 2.2106, 2.2106, 2.2106]],
[[-1.9287, -1.9287, -1.9287],
[ 0.8748, 0.8748, 0.8748]]])
size為(m,n,d)的張量,dim=1時,輸出為size為(m,d)程式設計客棧的張量
import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))
輸出:tensor(60)
tensor([[ 5, 10, 15],
[ 5, 10, 15]])
返回乙個記憶體為連續的張量,如本身就是連續的,返回它自己。一般用在view()函式之前,因為view()要求呼叫張量是連續的。
可以通過is_contiguous檢視張量記憶體是否連續。
import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous)
print(a.contiguous().view(4,3))
輸出:tensor([[ 1, 2, 3],
[ 4, 8, 12],
[ 1, 2,程式設計客棧 3],
[ 4, 8, 12]])
假設陣列v有c個元素。對其進行softmax等價於將v的每個元素的指數除以所有元素的指數之和。這會使值落在區間(0,1)上,並且和為1。
import torch
import torch.nn.functional as f
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=f.softmax(a,dim=1)
print(b)
輸出:tensor([[ 0.5000, 0.5000],
[ 0.7311, 0.2689],
[ 0.8808, 0.1192],
[ 0.2689, 0.7311],
[ 0.1192, 0.8808]])
返回最大值,或指定維度的最大值以及index
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())
輸出:(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)
返回最大值的index
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())
輸出:tensor([ 2, 2, 2, 2])
tensor(11)
本文標題: 對pytorch 中的contiguous理解說明
本文位址:
對Pytorch中backward()函式的理解
寫在第一句 這個部落格解釋的也很好,參考了很多 pytorch中的自動求導函式backward 所需引數含義 所以切入正題 backward 函式中的引數應該怎麼理解?官方 如果需要計算導數,可以在tensor上呼叫.backward 1.如果tensor是乙個標量 即它包含乙個元素的資料 則不需要...
oracle中實現break和continue
一 continue 在 oracle 11g 之前無法使用 continue 實現退出當前迴圈的 11g中據說實現了 但是可以用一下方法模擬實現 declare 定義變數 begin fori in 1.10loop 真正的迴圈 forj in 1.1loop 假迴圈,目的是模擬出 continu...
對PyTorch中inplace欄位的全面理解
torch.nn.relu inplace true inplace true 表示進行原地操作,對上一層傳遞下來的tensor直接進行修改,如x x 3 inplace false 表示新建乙個變數儲存操作結果,如y x 3,x y inplace true 可以節省運算記憶體,不用多儲存變數。補...