pytorch rnn 變長輸入序列問題

2021-09-19 17:27:50 字數 4046 閱讀 6640

輸入資料是長度不固定的序列資料,主要講解兩個部分

data.dataloader的collate_fn用法,以及按batch進行padding資料

pack_padded_sequence和pad_packed_sequence來處理變長序列

dataloader的collate_fn引數,定義資料處理和合併成batch的方式。

由於pack_padded_sequence用到的tensor必須按照長度從大到小排過序的,所以在collate_fn中,需要完成兩件事,一是把當前batch的樣本按照當前batch最大長度進行padding,二是將padding後的資料從大到小進行排序。

def

pad_tensor

(vec, pad)

:"""

args:

vec - tensor to pad

pad - the size to pad to

return:

a new tensor padded to 'pad'

"""return torch.cat(

[vec, torch.zeros(pad -

len(vec)

, dtype=torch.

float)]

, dim=0)

.data.numpy(

)class

collate

:"""

a variant of callate_fn that pads according to the longest sequence in

a batch of sequences

"""def__init__

(self)

:pass

def_collate

(self, batch)

:"""

args:

batch - list of (tensor, label)

reutrn:

xs - a tensor of all examples in 'batch' before padding like:

'''[tensor([1,2,3,4]),

tensor([1,2]),

tensor([1,2,3,4,5])]

'''ys - a longtensor of all labels in batch like:

'''[1,0,1]

'''"""

xs =

[torch.floattensor(v[0]

)for v in batch]

ys = torch.longtensor(

[v[1

]for v in batch]

)# 獲得每個樣本的序列長度

seq_lengths = torch.longtensor(

[v for v in

map(

len, xs)])

max_len =

max(

[len

(v)for v in xs]

)# 每個樣本都padding到當前batch的最大長度

xs = torch.floattensor(

[pad_tensor(v, max_len)

for v in xs]

)# 把xs和ys按照序列長度從大到小排序

seq_lengths, perm_idx = seq_lengths.sort(

0, descending=

true

) xs = xs[perm_idx]

ys = ys[perm_idx]

return xs, seq_lengths, ys

def__call__

(self, batch)

:return self._collate(batch)

定義完collate類以後,在dataloader中直接使用

train_data = data.dataloader(dataset=train_dataset, batch_size=

32, num_workers=

0, collate_fn=collate(

))

pack_padded_sequence將乙個填充過的變長序列壓緊。輸入引數包括

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# x是填充過後的batch資料,seq_lengths是每個樣本的序列長度

packed_input = pack_padded_sequence(x, seq_lengths, batch_first=

true

)

定義了乙個單向的lstm模型,因為處理的是變長序列,forward函式傳入的值是乙個packedsequence物件,返回值也是乙個packedsequence物件

class

model

(nn.module)

:def

__init__

(self, in_size, hid_size, n_layer, drop=

0.1, bi=

false):

super

(model, self)

.__init__(

) self.lstm = nn.lstm(input_size=in_size,

hidden_size=hid_size,

num_layers=n_layer,

batch_first=

true

, dropout=drop,

bidirectional=bi)

# 分類類別數目為2

self.fc = nn.linear(in_features=hid_size, out_features=2)

defforward

(self, x)

:'''

:param x: 變長序列時,x是乙個packedsequence物件

:return: packedsequence物件

'''# lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)

lstm_out, _ = self.lstm(x)

return lstm_out

model = model(

)lstm_out = model(packed_input)

這個操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來。因為前面提到的lstm模型傳入和返回的都是packedsequence物件,所以我們如果想要把返回的packedsequence物件轉換回tensor,就需要用到pad_packed_sequence函式。

引數說明:

返回值: 乙個tuple,包含被填充後的序列,和batch中序列的長度列表。

用法:

# 此處lstm_out是乙個packedsequence物件

output, _ = pad_packed_sequence(lstm_out)

返回的output是乙個形狀為(batch_size,seq_len,input_size)的tensor。

pytorch在自定義dataset時,可以在dataloader的collate_fn引數中定義對資料的變換,操作以及合成batch的方式。

處理變長rnn問題時,通過pack_padded_sequence()將填充的batch資料轉換成packedsequence物件,直接傳入rnn模型中。通過pad_packed_sequence()來將rnn模型輸出的packedsequence物件轉換回相應的tensor。

pytorch rnn 變長輸入序列問題

輸入資料是長度不固定的序列資料,主要講解兩個部分 data.dataloader的collate fn用法,以及按batch進行padding資料 pack padded sequence和pad packed sequence來處理變長序列 dataloader的collate fn引數,定義資料...

可變長字串

目錄stringbuilder 其他可變長字串,jdk1.0提供,執行效率慢,執行緒安全字串緩衝區 執行緒安全的可變字串 字串行 字串 如果字串需要頻繁修改,可用stringbuffer構造方法stringbuffer 初始容量為16個字元 stringbuffer int capacity 構造乙...

struct 封裝變長字串

使用struct,可以非常方便的處理二進位制資料,將常用的int,string等型別的資料轉成二進位制資料,它有兩個重要函式,乙個是pack,乙個是unpack 先看一張表 struct中支援的格式如下表 format c type python 位元組數x pad byte no value1c ...