這個類的例項不能手動建立。它們只能被pack_padded_sequence() 例項化。
torch.nn.utils.rnn.pack_padded_sequence()**輸入:input:[seq_length x batch_size x input_size] 或 [batch_size x seq_length x input_size],input中的seq要按照長度遞減的方式排列。
lengths:seq的長度列表,是乙個遞減的列表,與input裡的seq長度對應。ie. [5,4,1]
batch_first:bool變數,當它為true時,表示input為這種輸入形式[batch_size x seq_length x input_size],否則為另一種。
輸出:
乙個packedsequence物件,包含乙個variable型別的data,和鍊錶型別的batch_sizes。
batch的每乙個元素,代表data中,多少行為乙個batch。
例如:輸入為
input
variable containing:
(0 ,.,.) =
123(1 ,.,.) =
100[torch.floattensor of size 2x3x1]
lengths = [3, 1]
為了實現壓縮編碼,即把填充去除。我們最終的輸出為
packedsequence(data=variable containing: 11
23[torch.floattensor of size 4x1]
, batch_sizes=[2, 1, 1])
這就表明,前兩個1屬於乙個batch,後面兩個分別屬於不同的batch。換句話說,從batch_sizes可以看出,兩個seq的長度分別為1,3。後面的module或function可以根據batch_sizes讀取對應的資料。
**詳解
這裡我們以上面的輸入為例,研究該函式到底是怎麼實現資料壓縮的。
def
pack_padded_sequence
(input, lengths, batch_first=false):
# juge the length is > 0
if lengths[-1] <= 0:
raise valueerror("length of all samples has to be greater than 0, "
"but found an element in 'lengths' that is <=0")
# change the input into the shape of [seq_length x batch_size x input_size]
# here input is [3, 2, 1]
if batch_first:
input = input.transpose(0, 1)
steps =
batch_sizes =
# get the reversed iterator of the lengths
lengths_iter = reversed(lengths)
# here current_length == 1
current_length = next(lengths_iter)
batch_size = input.size(1)
if len(lengths) != batch_size:
raise valueerror("lengths array has incorrect size")
# here 1 indicate the 'step' start from 1
for step, step_value in enumerate(input, 1):
"""step_value == 1
1[torch.floattensor of size 2x1]
"""# juge if step to the end of a short seq
while step == current_length:
try:
new_length = next(lengths_iter)
except stopiteration:
current_length = none
break
# check the lengths if is a decrasing list
if current_length > new_length: # remember that new_length is the preceding length in the array
raise valueerror("lengths array has to be sorted in decreasing order")
# already step over a short seq, so the number of the batch should minus 1.
batch_size -= 1
current_length = new_length
if current_length is
none:
break
# here concat the list along the dim0.
return packedsequence(torch.cat(steps), batch_sizes)
nn.utils.rnn.pad_packed_sequence()這就是上乙個函式的逆操作。輸入是乙個packedsequence物件,包含batch_sizes,可以根據其對其中的data進行解耦。
pytorch總結學習系列 操作
算術操作 在pytorch中,同一種操作可能有很多種形式,下 用加法作為 加法形式 x torch.tensor 5.5,3 y torch.rand 5,3 print x y 加法形式 print torch.add x,y 還可指定輸出 result torch.empty 5,3 torch...
學習pytorch(四)簡單RNN舉例
import torch 簡單rnn學習舉例。rnn 迴圈神經網路 是把乙個線性層重複使用,適合訓練序列型的問題。單詞是乙個序列,序列的每個元素是字母。序列中的元素可以是任意維度的。實際訓練中,可以首先把序列中的元素變為合適的維度,再交給rnn層。學習 將hello 轉為 ohlol。dict e ...
pytorch總結學習系列 資料操作
在深度學習中,我們通常會頻繁地對資料進 行 操作。作為動 手學深度學習的基礎,本節將介紹如何對內 存中的資料進 行 操作。在pytorch中,torch.tensor 是儲存和變換資料的主要 工具。如果你之前 用過numpy,你會發現 tensor 和numpy的多維陣列 非常類似。然 tensor...