pytorch長度不同的資料如何放在乙個batch

2021-10-09 07:38:47 字數 1850 閱讀 8755

rnn及其變種演算法處理一維訊號經常會遇到訊號長度不一致的問題。

from torch.utils.data import dataloader

dataloader = dataloader(dataset, batch_size=

8)

這樣是沒法成功載入dataset,因為dataloader要求乙個batch內的資料shape是一致的,才能打包成乙個方塊投入模型。我們看一下原始碼裡dataloader初始化的方法

def

__init__

(self, dataset, batch_size=

1, shuffle=

false

, sampler=

none

, batch_sampler=

none

, num_workers=

0, collate_fn=

none

, pin_memory=

false

, drop_last=

false

, timeout=

0,

其中collate_fn是pytorch為我們提供的資料裁剪函式,當collate_fn=none時,初始化會呼叫預設的裁剪方式即直接將資料打包,所以這時如果資料shape不一致,會打包不成功。因此我們需要自己寫乙個collate_fn函式,我常用的兩種方式是:1.將所有資料截斷到和最短的資料一樣長;2.將所有的資料補零到和最長的資料一樣長。

這裡給出第一種方法的實現方式,第二種稍微複雜一點,也不難:

import torch

defcollate_fn

(data)

:# 這裡的data是乙個list, list的元素是元組,元組構成為(self.data, self.label)

# collate_fn的作用是把[(data, label),(data, label)...]轉化成([data, data...],[label,label...])

# 假設self.data的乙個data的shape為(channels, length), 每乙個channel的length相等,data[索引到資料index][索引到data或者label][索引到channel]

data.sort(key=

lambda x:

len(x[0]

[0])

, reverse=

false

)# 按照資料長度公升序排序

data_list =

label_list =

min_len =

len(data[0]

[0][

0])# 最短的資料長度

for batch in

range(0

,len

(data)):

#[0]

[:,:min_len])[

1]) data_tensor = torch.tensor(data_list, dtype=torch.float32)

label_tensor = torch.tensor(label_list, dtype=torch.float32)

data_copy =

(data_tensor, label_tensor)

return data_copy

使用方法也很簡單

dataloader = torch.utils.data.dataloader(dataset, collate_fn=collate_fn)

pytorch模型運算時間的不同

今天測試模型計算時間時,執行了這樣一段 在呼叫時發現 首次呼叫第乙個模型的計算時間 會比之後的模型計算時間長很多,執行結果如圖 可以看到即便是同樣是resnet50,第一次呼叫和第二次呼叫時間差別非常大。而且即便我改變了模型載入的順序,在呼叫時仍然是第乙個模型的計算時間遠超其他模型。因此我判斷pyt...

pytorch 實現模型不同層設定不同的學習率方式

在目標檢測的模型訓練中,我們通常都會有乙個特徵提取網路backbone,例如yolo使用的darknet ssd使用的vgg 16。為了達到程式設計客棧比較好的訓練效果,往往會載入預訓練的backbone模型引數,然後在此基礎上訓練檢測網路,並對backbone進行微調,這時候就需要為backbon...

C語言 不同資料型別長度獲取問題

我們能常在用到 sizeof 和 strlen 的時候,通常是計算字串陣列的長度,c語言中有乙個可以獲取字串長度的函式strlen並且與sizeof做對比 extern unsigned int strlen char s 其中形參只能為字元指標型別,其從給定變數的第乙個位置開始掃瞄,直到遇到 0 ...