1. pytorch 對於資料的標籤要求是長整形,因此要對標籤進行轉換
train_label = train_label.long()
2.對於資料的特徵部分,要轉換為tensor 形式,可以通過torch.tensor 將資料從numpy 轉為tensor
train_fea = torch.tensor(train_fea, dtype=torch.float32)
3.封裝資料
data = torch.utils.data.tensordataset(train_fea, train_label)#(特徵,標籤)
4.封裝成第三步的形式,就可以採用torch 中的資料載入器為模型提供資料,資料載入器可以自動分批餵給模型資料
train_loader = torch.utils.data.dataloader(data,batch_size=batch_size, shuffle=true, drop_last=true, **kwargs)
pytorch 載入自己的資料集
pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的 init 方法 getitem 方法 len 方法。預設載入的都是,init 的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用 getitem 方法得到...
Pytorch學習 一) 定義自己的資料集
torch.utils.data.dataset 是資料集的抽象類,當我們定義自己的資料集都要繼承這個方法,並且必須覆蓋它的 len 和 getitem 這個兩個方法,len 提高了資料集的大小,getitem 用來索引資料集中每個樣本,import torch.utils.data import ...
pytorch 載入資料集
2 tensor 的 構造方式 import torch import numpy as np data np.array 1,2,3 print torch.tensor data 副本 print torch.tensor data 副本 print torch.as tensor data 檢...