We need to define DiabetesDataset as the class that loads the diabetes dataset. It inherits from Dataset, which is an abstract class, so it must implement three methods: __init__, __getitem__, and __len__.
import torch
from torch.utils.data import Dataset  # abstract base class
from torch.utils.data import DataLoader
import numpy as np


class DiabetesDataset(Dataset):  # inherits from Dataset
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                        # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])    # all columns except the last are features
        self.y_data = torch.from_numpy(xy[:, [-1]])   # the last column is the label

    def __getitem__(self, index):  # supports indexing: fetch one sample by its index
        return self.x_data[index], self.y_data[index]

    def __len__(self):  # number of samples in the dataset
        return self.len


dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,   # dataset to draw from
                          batch_size=32,     # samples per mini-batch
                          shuffle=True,      # shuffle the data each epoch
                          num_workers=0)     # number of loader workers; must be 0 on Windows, can be > 0 on Linux
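To see what the DataLoader yields on each iteration, here is a minimal, self-contained sketch that substitutes random tensors for diabetes.csv.gz (the random data and the name demo_loader are assumptions for illustration only); each batch is a pair of tensors of shape (batch_size, 8) and (batch_size, 1):

import torch
from torch.utils.data import DataLoader, TensorDataset

x = torch.rand(100, 8)                  # stand-in features: 100 samples, 8 columns
y = (torch.rand(100, 1) > 0.5).float()  # stand-in binary labels
demo_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True, num_workers=0)

for i, (inputs, labels) in enumerate(demo_loader):
    # With 100 samples and batch_size=32, the last batch holds the remaining 4 samples.
    print(i, inputs.shape, labels.shape)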
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()  # unlike nn.functional.sigmoid, this is a module and becomes part of the computation graph

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
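As a quick sanity check on the layer sizes (not part of the original code), an untrained instance maps a batch of 8-feature inputs to one value per sample in (0, 1):

check = Model()                  # fresh instance used only for this shape check
out = check(torch.rand(5, 8))    # hypothetical batch of 5 samples with 8 features
print(out.shape)                 # torch.Size([5, 1]); the final sigmoid keeps values in (0, 1)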
model = Model()
criterion = torch.nn.BCELoss(reduction='mean')            # loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # optimizer
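For reference, BCELoss with reduction='mean' averages -[y*log(p) + (1-y)*log(1-p)] over the batch; a small hand-checked sketch with made-up values:

p = torch.tensor([0.9, 0.2])    # predicted probabilities, must lie in [0, 1]
y = torch.tensor([1.0, 0.0])    # targets
# mean(-log(0.9), -log(1 - 0.2)) ≈ 0.1643
print(torch.nn.BCELoss(reduction='mean')(p, y).item())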
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
            # 1. prepare data
            inputs, labels = data
            # 2. forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. backward
            optimizer.zero_grad()
            loss.backward()
            # 4. update
            optimizer.step()
Output:
0 0 0.6936783194541931
0 1 0.693471372127533
0 2 0.6917673349380493
0 3 0.6861389875411987
0 4 0.6913132667541504
0 5 0.6789288520812988
0 6 0.6768878698348999
0 7 0.6651645302772522
0 8 0.6861144304275513
0 9 0.6686166524887085
0 10 0.6661809682846069
0 11 0.6636384129524231
0 12 0.6618748307228088
0 13 0.6681938767433167
0 14 0.6153277158737183
0 15 0.6548603773117065
... ...
98 14 0.5910221338272095
98 15 0.6699521541595459
98 16 0.6283824443817139
98 17 0.6495291590690613
98 18 0.6865949630737305
98 19 0.6016601920127869
98 20 0.630635678768158
98 21 0.6044492721557617
98 22 0.6302173137664795
98 23 0.6102578043937683
99 0 0.5284566283226013
99 1 0.6872431039810181
99 2 0.6330350041389465
99 3 0.6103817820549011
99 4 0.6251040697097778
99 5 0.6059320569038391
99 6 0.6281994581222534
99 7 0.6733802556991577
99 8 0.6273549795150757
99 9 0.7067252993583679
99 10 0.6479067802429199
99 11 0.7034580111503601
99 12 0.633543848991394
99 13 0.5920330882072449
99 14 0.6311102509498596
99 15 0.6479007601737976
99 16 0.6280706524848938
99 17 0.6995146870613098
99 18 0.6469420790672302
99 19 0.6414950489997864
99 20 0.5969923734664917
99 21 0.5866757035255432
99 22 0.5923041105270386
99 23 0.524055004119873
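The original post only prints the loss. If you wanted to inspect predictions after training, one possible (assumed) follow-up is to threshold the output at 0.5 and measure accuracy on the training data:

# Hypothetical evaluation step, not in the original code.
with torch.no_grad():
    y_prob = model(dataset.x_data)       # probabilities from the final sigmoid
    y_hat = (y_prob >= 0.5).float()      # the 0.5 threshold is an assumption
    accuracy = (y_hat == dataset.y_data).float().mean().item()
    print('accuracy:', accuracy)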