Pytorch入門之線性回歸

2021-10-10 01:16:51 字數 2336 閱讀 7211

這裡定義乙個簡單的神經網路來做乙個線性回歸問題

神經元之間的線就不連了,大家知道是個全連線層就好

搭建這樣乙個網路,首先就是需要定義乙個class,class必須得繼承nn.module類,常用來被繼承,然後使用者去編寫自己的網路/層。

類中的初始化部分需要去例化自己的層。這裡需要定義2個全連線層,因此我們可以直接呼叫nn.linear這個類,關於這個類不清楚的,可以看看我的另一篇文章:

self.hidden_layers = nn.linear(feature_in, feature_hidden)

self.predict_layers = nn.linear(feature_hidden, feature_out)

類中的前向傳播函式需要去模擬前向傳播的過程,即輸入經過全連線層然後通過激勵函式輸出,接著又通過輸出端的全連線層繼續輸出成最終的**值

out0 = self.hidden_layers(datain).relu()

out1 = self.predict_layers(out0)

這樣一來,我們的網路就編寫好了。

接下來,就是去生成這個網路,就是例化這個類。接著定義乙個有序容器squential,將上面這個網路依次裝進去

net = net(1, 1, 10)

network = nn.sequential()

network.add_module('full0', net.hidden_layers)

network.add_module('dull1', net.predict_layers)

我們可以列印一下這個網路

很清晰的顯示了這2個全連線層。

網路生成好了,接下來就是對網路引數進行初始化

我們匯入torch.nn中的init模組

# 初始化引數w b

init.normal_(network[0].weight, 0, 0.01) # 服從均值0,方差0.01的正太分布

init.normal_(network[1].weight, 0, 0.02) # 服從均值0,方差0.01的正太分布

init.constant_(network[0].bias, 0) # 填充函式

init.constant_(network[1].bias, 0) # 填充函式

我們可以列印一下這個引數看看,net.parameters(),用來檢視網路net中的引數,全連線層中的引數就是(w, b)

1、定義損失函式,這裡我們就選均方誤差

loss = nn.mseloss()
2、生成優化器 ,這裡我們選用sgd

optimizer = optim.sgd(network.parameters(), lr=0.2)  #例化物件
有關sgd的詳細介紹在我的另一篇文章中,有興趣的可以看看:

在提前定義好的epochs下,進行前向傳播,計算loss,然後清空梯度,反向傳播計算梯度,sgd優化損失函式。寫成**就是:

for epoch in range(epochs):

out_forward = network(y)

target_func = loss(out_forward, y)

optimizer.zero_grad() # 或者network.zero_grad()

target_func.backward()

optimizer.step()

這一樣一來,整個網路就訓練好了

這是最後在訓練集上擬合的情況:

總結一下,整個學習任務的四步驟:

1、網路類的編寫及生成網路

2、初始化模型引數

3、訓練前的準備,即定義損失函式、生成優化器

4、前向推理、反向傳播、訓練 

Pytorch之線性回歸

import torch from torch import nn import numpy as np import torch.utils.data as data from torch.nn import init 使得模型的可復現性 torch.manual seed 1 設定預設的資料格式...

pytorch線性回歸

線性回歸假設輸出與各個輸入之間是線性關係 y wx b pytorch定義線性回歸模型 def linreg x,w,b return torch.mm x,w b線性回歸常用損失函式是平方損失 2優化函式 隨機梯度下降 小批量隨機梯度下降 mini batch stochastic gradien...

pytorch 線性回歸

import torch from torch.autograd import variable import torch.nn.functional as f import matplotlib.pyplot as plt print torch.linspace 1,1,100 x torch....