封裝自己的pytorch資料集

2021-10-08 13:02:17 字數 576 閱讀 4131

1. pytorch 對於資料的標籤要求是長整形,因此要對標籤進行轉換

train_label = train_label.long()
2.對於資料的特徵部分,要轉換為tensor 形式,可以通過torch.tensor 將資料從numpy 轉為tensor

train_fea = torch.tensor(train_fea, dtype=torch.float32)
3.封裝資料

data = torch.utils.data.tensordataset(train_fea, train_label)#(特徵,標籤)
4.封裝成第三步的形式,就可以採用torch 中的資料載入器為模型提供資料,資料載入器可以自動分批餵給模型資料

train_loader = torch.utils.data.dataloader(data,batch_size=batch_size, shuffle=true, drop_last=true, **kwargs)

pytorch 載入自己的資料集

pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的 init 方法 getitem 方法 len 方法。預設載入的都是,init 的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用 getitem 方法得到...

Pytorch學習 一) 定義自己的資料集

torch.utils.data.dataset 是資料集的抽象類,當我們定義自己的資料集都要繼承這個方法,並且必須覆蓋它的 len 和 getitem 這個兩個方法,len 提高了資料集的大小,getitem 用來索引資料集中每個樣本,import torch.utils.data import ...

pytorch 載入資料集

2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...