輸入資料是長度不固定的序列資料,主要講解兩個部分
data.dataloader的collate_fn用法,以及按batch進行padding資料
pack_padded_sequence和pad_packed_sequence來處理變長序列
dataloader的collate_fn引數,定義資料處理和合併成batch的方式。
由於pack_padded_sequence用到的tensor必須按照長度從大到小排過序的,所以在collate_fn中,需要完成兩件事,一是把當前batch的樣本按照當前batch最大長度進行padding,二是將padding後的資料從大到小進行排序。
def pad_tensor(vec, pad):
"""args:
vec - tensor to pad
pad - the size to pad to
return:
a new tensor padded to 'pad'
"""return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy()
class collate:
"""a variant of callate_fn that pads according to the longest sequence in
a batch of sequences
"""def __init__(self):
pass
def _collate(self, batch):
"""args:
batch - list of (tensor, label)
reutrn:
xs - a tensor of all examples in 'batch' before padding like:
'''[tensor([1,2,3,4]),
tensor([1,2]),
tensor([1,2,3,4,5])]
'''ys - a longtensor of all labels in batch like:
'''[1,0,1]
'''"""
xs = [torch.floattensor(v[0]) for v in batch]
ys = torch.longtensor([v[1] for v in batch])
# 獲得每個樣本的序列長度
seq_lengths = torch.longtensor([v for v in map(len, xs)])
max_len = max([len(v) for v in xs])
# 每個樣本都padding到當前batch的最大長度
xs = torch.floattensor([pad_tensor(v, max_len) for v in xs])
# 把xs和ys按照序列長度從大到小排序
seq_lengths, perm_idx = seq_lengths.sort(0, descending=true)
xs = xs[perm_idx]
ys = ys[perm_idx]
return xs, seq_lengths, ys
def __call__(self, batch):
return self._collate(batch)
定義完collate類以後,在dataloader中直接使用
train_data = data.dataloader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=collate())
pack_padded_sequence將乙個填充過的變長序列壓緊。輸入引數包括
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# x是填充過後的batch資料,seq_lengths是每個樣本的序列長度
packed_input = pack_padded_sequence(x, seq_lengths, batch_first=true)
定義了乙個單向的lstm模型,因為處理的是變長序列,forward函式傳入的值是乙個packedsequence物件,返回值也是乙個packedsequence物件
class model(nn.module):
def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=false):
super(model, self).__init__()
self.lstm = nn.lstm(input_size=in_size,
hidden_size=hid_size,
num_layers=n_layer,
batch_first=true,
dropout=drop,
bidirectional=bi)
# 分類類別數目為2
self.fc = nn.linear(in_features=hid_size, out_features=2)
def forward(self, x):
''':param x: 變長序列時,x是乙個packedsequence物件
:return: packedsequence物件
'''# lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)
lstm_out, _ = self.lstm(x)
return lstm_out
model = model()
lstm_out = model(packed_input)
這個操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來。因為前面提到的lstm模型傳入和返回的都是packedsequence物件,所以我們如果想要把返回的packedsequence物件轉換回tensor,就需要用到pad_packed_sequence函式。
引數說明:
返回值: 乙個tuple,包含被填充後的序列,和batch中序列的長度列表。
用法:
# 此處lstm_out是乙個packedsequence物件
output, _ = pad_packed_sequence(lstm_out)
返回的output是乙個形狀為(batch_size,seq_len,input_size)的tensor。
pytorch在自定義dataset時,可以在dataloader的collate_fn引數中定義對資料的變換,操作以及合成batch的方式。
處理變長rnn問題時,通過pack_padded_sequence()將填充的batch資料轉換成packedsequence物件,直接傳入rnn模型中。通過pad_packed_sequence()來將rnn模型輸出的packedsequence物件轉換回相應的tensor。
pytorch rnn 變長輸入序列問題
輸入資料是長度不固定的序列資料,主要講解兩個部分 data.dataloader的collate fn用法,以及按batch進行padding資料 pack padded sequence和pad packed sequence來處理變長序列 dataloader的collate fn引數,定義資料...
可變長字串
目錄stringbuilder 其他可變長字串,jdk1.0提供,執行效率慢,執行緒安全字串緩衝區 執行緒安全的可變字串 字串行 字串 如果字串需要頻繁修改,可用stringbuffer構造方法stringbuffer 初始容量為16個字元 stringbuffer int capacity 構造乙...
struct 封裝變長字串
使用struct,可以非常方便的處理二進位制資料,將常用的int,string等型別的資料轉成二進位制資料,它有兩個重要函式,乙個是pack,乙個是unpack 先看一張表 struct中支援的格式如下表 format c type python 位元組數x pad byte no value1c ...