梯度下降演算法是乙個很基本的演算法,在機器學習和優化中有著非常重要的作用,本文首先介紹了梯度下降的基本概念,然後使用python實現了乙個基本的梯度下降演算法。梯度下降有很多的變種,本文只介紹最基礎的梯度下降,也就是批梯度下降。
實際應用例子就不詳細說了,網上關於梯度下降的應用例子很多,最多的就是ng課上的**房價例子:
假設有乙個房屋銷售的資料如下:
面積(m^2) 銷售價錢(萬元)
面積(m^2)
銷售價錢(萬元)
123250
150320
87180
根據上面的房價我們可以做這樣乙個圖:
於是我們的目標就是去擬合這個圖,使得新的樣本資料進來以後我們可以方便進行**:
對於最基本的線性回歸問題,公式如下:
x是自變數,比如說房子面積。θ是權重引數,也就是我們需要去梯度下降求解的具體值。
在這兒,我們需要引入損失函式(loss function 或者叫 cost function),目的是為了在梯度下降時用來衡量我們更新後的引數是否是向著正確的方向前進,如圖損失函式(m表示訓練集樣本數量):
下圖直觀顯示了我們梯度下降的方向,就是希望從最高處一直下降到最低出:
梯度下降更新權重引數的過程中我們需要對損失函式求偏導數:
求完偏導數以後就可以進行引數更新了:
偽**如圖所示:
好了,下面到了**實現環節,我們用python來實現乙個梯度下降演算法,求解:y=
2x1+
x2+3
,也就是求解:y=
ax1+
bx2+
c 中的a,b,c三個引數 。
下面是**:
import numpy as np
import matplotlib.pyplot as plt
#y=2 * (x1) + (x2) + 3
rate = 0.001
x_train = np.array([ [1, 2], [2, 1], [2, 3], [3, 5], [1, 3], [4, 2], [7, 3], [4, 5], [11, 3], [8, 7] ])
y_train = np.array([7, 8, 10, 14, 8, 13, 20, 16, 28, 26])
x_test = np.array([ [1, 4], [2, 2], [2, 5], [5, 3], [1, 5], [4, 1] ])
a = np.random.normal()
b = np.random.normal()
c = np.random.normal()
defh
(x):
return a*x[0]+b*x[1]+c
for i in range(10000):
sum_a=0
sum_b=0
sum_c=0
for x, y in zip(x_train, y_train):
sum_a = sum_a + rate*(y-h(x))*x[0]
sum_b = sum_b + rate*(y-h(x))*x[1]
sum_c = sum_c + rate*(y-h(x))
a = a + sum_a
b = b + sum_b
c = c + sum_c
plt.plot([h(xi) for xi in x_test])
print(a)
print(b)
print(c)
result=[h(xi) for xi in x_train]
print(result)
result=[h(xi) for xi in x_test]
print(result)
plt.show()
x_train是訓練集x,y_train是訓練集y, x_test是測試集x,執行後得到如下的圖,顯示了演算法對於測試集y的**在每一輪迭代中是如何變化的:
我們可以看到,線段是在逐漸逼近的,訓練資料越多,迭代次數越多就越逼近真實值。
參考文章:
梯度下降演算法及Python實現
梯度下降是乙個用來求函式最小值的演算法,其背後的思想是 開始時我們隨機選擇乙個引數的組合,計算代價函式,然後我們尋找下乙個能讓代價函式值下降最多的引數組合。我們持續這麼做直到到達乙個區域性最小值,因為我們並沒有嘗試完所有的引數組合,所以不能確定我們得到的區域性最小值是否便是全域性最小值,選擇不同的初...
梯度下降原理
在不一定可解的情況下,不用公式來求最小值,而是通過嘗試。線性回歸是通過公式來求解。梯度為求偏導,偏導值為梯度。下降為偏導值的下降方向。常規套路 機器學習的套路就是交給機器一堆資料,然後告訴它怎樣的方向是對的 目標函式 然後它朝著這個方向去做。梯度下降就是求 什麼樣的引數能是目標函式達到最小值點。目標...
梯度下降演算法的簡單Python原理實現
梯度下降 導數值下降 import matplotlib.pyplot as plt import numpy as np 目標函式 f x x 2 梯度函式 一階導數函式 f x 2 x 梯度下降演算法是乙個方法,是幫助我們找極值點的方法cost 梯度下降 導數值下降 import matplot...