Task4 用PyTorch實現多層網路

2022-06-13 17:18:15 字數 2532 閱讀 9001

1.引入模組,讀取資料

2.構建計算圖(構建網路模型)

3.損失函式與優化器

4.開始訓練模型

5.對訓練的模型**結果進行評估

1

import

torch.nn.functional as f

2import

torch.nn.init as init

3import

torch

4from torch.autograd import

variable

5import

matplotlib.pyplot as plt

6import

numpy as np

7import

math

8 %matplotlib inline9#

%matplotlib inline 可以在ipython編譯器裡直接使用10#

功能是可以內嵌繪圖,並且可以省略掉plt.show()這一步。

1112 xy=np.loadtxt('

./data/diabetes.csv.gz

',delimiter='

,',dtype=np.float32)

13 x_data=torch.from_numpy(xy[:,0:-1])#

取除了最後一列的資料

14 y_data=torch.from_numpy(xy[:,[-1]])#

取最後一列的資料,[-1]加中括號是為了keepdim

1516

print

(x_data.size(),y_data.size())17#

print(x_data.shape,y_data.shape)

1819

#建立網路模型

20class

model(torch.nn.module):

2122

def__init__

(self):

23 super(model,self).__init__

()24 self.l1=torch.nn.linear(8,6)

25 self.l2=torch.nn.linear(6,4)

26 self.l3=torch.nn.linear(4,1)

2728

defforward(self,x):

29 out1=f.relu(self.l1(x))

30 out2=f.dropout(out1,p=0.5)

31 out3=f.relu(self.l2(out2))

32 out4=f.dropout(out3,p=0.5)

33 y_pred=f.sigmoid(self.l3(out3))

34return

y_pred

3536

defweights_init(m):

37 classname=m.__class__.__name__

38if classname.find('

linear

')!=-1:

39 m.weight.data=torch.randn(m.weight.data.size()[0],m.weight.data.size()[1])

40 m.bias.data=torch.randn(m.bias.data.size()[0])

4142

#our model

43 model=model()

4445 criterion=torch.nn.bceloss()

46 optimizer=torch.optim.sgd(model.parameters(),lr=0.1)

4748

#training loop

49 loss=

50for epoch in range(2000):

51 y_pred=model(x_data)

52 loss=criterion(y_pred,y_data)

53if epoch%100 ==0:

54print("

epoch =

",epoch,"

loss =

",loss.data)

5556

optimizer.zero_grad()

57loss.backward()

58optimizer.step()

5960 hour_var = variable(torch.randn(1,8))

61print("

predict

",model(hour_var).data[0]>0.5)

62 plt.plot(loss)

動手學深度學習PyTorch版 task4

目錄 task1 task2 task3 task4 本章節 task5 task6 task8 task9 task10 1 機器翻譯及相關技術 機器翻譯 mt 將一段文字從一種語言自動翻譯為另一種語言,用神經網路解決這個問題通常稱為神經機器翻譯 nmt 主要特徵 輸出是單詞序列而不是單個單詞。輸...

學習筆記Task4

趕專案進度,僅了解 神經網路是由具有適應性的簡單單元所組成的廣泛並行互連的網路,它的組織能夠模擬生物神經系統對真實世界所做出的的互動反應。神經網路最基本的成分是神經元模型,當通過神經元的資訊信好超過某乙個閾值,那麼該神經元就會啟用,從而作用於下乙個神經元。在m p神經元模型中,神經元接收到來自n個其...

Task4 三數之和

給定乙個包含 n 個整數的陣列 nums,判斷 nums 中是否存在三個元素 a,b,c 使得 a b c 0 找出所有滿足條件且不重複的三元組。注意 答案中不可以包含重複的三元組。class solution def threesum self,nums list int list list in...