Pytorch學習之旅 2 線性回歸實戰

2021-10-05 21:11:57 字數 3617 閱讀 9778

回歸分析中,只包括乙個自變數和乙個因變數,且二者的關係可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析,即 y = w * x + b .

m od

el:y

=w∗x

+bmodel: y = w * x + b

model:

y=w∗

x+b採用最小平方法求解一元線性模型最優引數 w、b。

m se

:1m∑

i=1m

(ype

rd−y

)2mse: \frac\sum_^(y_-y)^2

mse:m1

​i=1

∑m​(

yper

d​−y

)2w =w

−lr∗

w.gr

adw = w - lr * w.grad

w=w−lr

∗w.g

radb=b

−lr∗

b.gr

adb = b - lr*b.grad

b=b−lr

∗b.grad

lr =

0.01

# 學習率

# 建立訓練資料

x = torch.rand(20,

1)*10

# 設定引數 x,x.shape=torch.size([20, 1])

y =2

* x +(5

+ torch.rand(20,

1))# 設定引數 y, y.shape=torch.size([20, 1])

w = torch.randn((1

), requires_grad=

true

)# 設定引數 w,w.shape=torch.size([1])

b = torch.zeros((0

), requires_grad=

true

)# 設定引數 b,b.shape=torch.size([1])

# 前向傳播

wx = torch.mul(w, x)

y_pred = torch.add(wx, b)

# 計算mse.loss

loss =

(0.5

*(y - y_pred)**2

).mean(

)

# 反向傳播

loss.backward(

)

# 更新引數

w.data.sub_(lr * w.grad)

b.data.sub_(lr * b.grad)

# 清零張量的梯度

w.grad.zero_(

)b.grad.zero_(

)

plt.scatter(x.data.numpy(

), y.data.numpy())

plt.plot(x.data.numpy(

), y_pred.data.numpy(),

'r-'

, lw =5)

plt.text(2,

20,'loss = %.4f'

%loss.data.numpy(

), fontdict=

)plt.xlim(

1.5,20)

plt.ylim(8,

28)plt.title(

'iterator:{}\nw:{}\nb:{}'

.format

(iteration, w.data.numpy(

), b.data.numpy())

)plt.pause(

0.5)

import torch

import matplotlib.pyplot as plt

lr =

0.01

# 學習率

# 建立訓練資料

x = torch.rand(20,

1)*10

y =2

* x +(5

+ torch.rand(20,

1))w = torch.randn((1

), requires_grad=

true

)b = torch.zeros((1

), requires_grad=

true

)for iteration in

range

(1000):

# 前向傳播

wx = torch.mul(w, x)

y_pred = torch.add(wx, b)

# 計算mse.loss

loss =

(0.5

*(y - y_pred)**2

).mean(

)# 反向傳播

loss.backward(

)# 更新引數

w.data.sub_(lr * w.grad)

b.data.sub_(lr * b.grad)

# 清零張量的梯度

w.grad.zero_(

) b.grad.zero_(

)# 繪圖

1、改變學習率 lr,可以改變梯度下降速率,加快擬合速度。

2、當資料過於分散時,計算的熵值可能達不到預定值以下,所以要根據實際情況調整預定值。

pytorch學習之旅 1

torch.tensor是torch.tensor與torch.empty的一種混合。當傳入資料時,torch.tensor使用全域性預設的dtype floattensor,而torch.tensor從資料中推斷資料型別。import torch t1 torch.tensor 2 3 t2 to...

Pytorch學習之旅 4 邏輯回歸實戰

資料處理 建立模型 選擇損失函式 選擇優化器 迭代訓練import torch import torch.nn as nn import matplotlib.pyplot as plt import numpy as np torch.manual seed 10 step 1 5 生成資料 sa...

pytorch學習筆記3 線性回歸

線性回歸線性回歸 是分析乙個變數與另外乙個 多 個變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 求解w,b 求解步驟 1.確定模型 module y wx b 2.選擇損失函式 mse 均方差等 3.求解梯度並更新w,b w w lr w.grad b b lr w.grad...