用梯度下降演算法訓練神經網路的時候,求導過程是其中的關鍵計算之一。使用tensorflow的使用者會發現,神經網路的反向傳播計算是使用者不用考慮的,在給足便捷性的同時也抑制了使用者對反向傳播的探索心態(博主深受其害)。tensorflow同時也激起了乙個思考:一定存在某種求導的通用方法。
這篇文章主要探索程式設計求導的通法。
方法1: 定義導函式:
對於函式:y=
x2y =x
2其導數為:y′
=2∗x
y ′=
2∗
x對於函式:y=
sin(
x)y =s
in(x
)其導數為:y′
=cos
(x) y′=
cos(
x)
這裡的導函式都需要開發者自己用筆算出來,然後用**表示出來。為什麼可以這麼做?因為神經網路裡用到的函式種類有限,我們可以把所有的函式導函式都預先定義出來。如下:
# 目標函式:y=x^2
deffunc
(x):
return x**2
# 目標函式一階導數:dy/dx=2*x
defdfunc
(x):
return
2*x
來看這篇文章的,估計不是為了得到這個答案,為的是解決一般情況的求導公式。而且對於千變萬化的loss function,每次都自己手動求導就很麻煩了。
csdn上有不少講解求導的文章,有碼友說通過把各種常見基本函式的導數定義好,然後對任何複雜函式進行分解,再查表呼叫基本函式的導函式,例如y=
sin(
x)+x
2 y=s
in(x
)+x2
拆分成兩個基本函式y1
y
1=si
n(x)
s in
(x
)和y2
= y2=
x2 x
2。事實上,這種做法也很難,分解複雜函式是最大的難點。比如y=
sin(
x2/x
) y=s
in(x
2/x)
,這要怎麼通過一般式分解呢,當然可以分解,不過分複雜。博主認為,這種思路明顯把這個問題弄複雜了。
這種方法,歸根結底還是屬於查表法。
是否認真想過,求導計算能不能用非查表法解決?
方法2:使用專門求導模組
python提供了乙個十分好用的求導模組sympy,當然並不是唯一能用來求導的模組。
這個求導方法如下:
from sympy import *
x = symbol("x") # 把"x"設定為自變數
fu = diff(x ** 2, x) # 求導x^2
print(fu)
輸出:
2
*x
可以試著把原函式寫複雜點: y=
sin(
x2/x
) y=s
in(x
2/x)
改為fu = diff(sin(x**2)/x, x)
輸出為:
2*cos(x*
*2) - sin(x*
*2)/x**2
可以發現:這個方法可以求解複雜函式的導數。
其實這個方法還有若干缺陷:輸入的x只能是數,不能是numpy陣列。我們知道,當進行反向傳播的時候,輸入往往都是陣列。此外,很多函式並不是用函式表示式來表示的(下面會舉例),沒有確切的函式表示式是無法使用這個方法來求導的。
# 目標函式:y=x^2
deffunc
(x):
return np.square(x)
如上,這個函式沒有確切的函式表示式(雖然知道是y=
x2y =x
2,但無法用這種辦法求導)
如果嚮導函式輸入數字:
from sympy import *
x = symbol("x")
fu = diff(sin(x**2) / x, x)
print(fu.subs('x', 3))
輸出: 6
如果嚮導函式輸入陣列:
from sympy import *
a = np.array([[1,3,5,7],[7,7,0,0]])
x = symbol("x")
fu = diff(sin(x**2) / x, x)
print(fu.subs('x', a))
輸出: 2∗
x 2∗x
可見,這種方法無法把陣列當作輸入。只能把陣列的每乙個元素抽出來作為單獨的數輸入到導函式中求解,這種操作非常慢(矩陣計算可以用gpu並行執行,而轉換成單獨的數字就不能利用gpu的並行性了)。
方法3. 通過導數的數學定義求導
方法2可以讓高數初學者驗證自己求導是否正確。方法3可以在實際過程中解決問題:計算速度和精度都達到不錯的程度。這就是所謂的「無敵」方法。導數定義如下: f′
(x)=
limδx→
0[f(
x+δx
)−f(
x)]/
δxf ′(
x)
=limδx
→0[f
(x+δ
x)−f
(x)]
/δ
x直接用乙個很小的數來代替δx
δ
x就可以在計算機上實現這個求導公式了。
實現**如下:
def
derive
(f):
dx = 0.0000000001
return
lambda x: (f(x + dx) - f(x)) / dx
對,這種方法並不能給你返回乙個導函式公式,你輸入x2
x
2並不能返回2∗
x 2∗x
。但在計算機系統中,求導計算並不需要你準確的導函式公式,只要導函式函式值的計算精度高便可以使用。這種方法,尤其在神經網路的方向傳播中十分有用!
具體怎麼使用呢,我們可以改造一下上面的例子:
# 目標函式:y=x^2
deffunc
(x):
return np.square(x)
# 目標函式一階導數:dy/dx=2*x
defdfunc
(x):
return derive(func)(x)
defderive
(f):
dx = 0.0000000001
return
lambda x: (f(x + dx) - f(x)) / dx
可以看出這裡的derive方法可以通用。但是呢,這種方法有點點缺陷說明一下:滿足數學意義上具有可導性的函式才能用這種方法求解,比如relu(x)函式,它在x=0處不可導,那麼我們需要對它人工分段,不然可能會有尖刺。雖然對最後結果影響幾乎沒有,不過注意一下還是好的。 RSA 數學原理
提起rsa大家一定不陌生,在開發中經常使用,也經常聽同事說道。話說很久以前,人們就懂的了加密這個技術。在戰爭時期,間諜就會拿著密文和密匙來對資訊就行傳遞。這種簡單的密文 密匙 key 就是對稱加密 加密 明文 密匙 解密 密文 密匙 由於這種加密方式過於簡單,所以後來引入了數學演算法。rsa就是由特...
數學,原理,方法,技巧
學校裡學到的東西為什麼沒有用處?主要是學到的東西大部份都有人去實現了。比如資料結構中學做乙個二叉樹。其實在外邊幹活的時候根本不需要。誰會讓你去編寫乙個二叉樹?即使做專案的時候真的有用,大概也有已經實現的類代替了。所以有人說學校裡學到的東西基本沒有用。不過也難怪,因為這些人只是學到了原理或者方法。學校...
補碼的數學原理
計算機是用n位0和1來表示數字的,這樣很容易表示正數,但是怎麼表示負數呢?人類聰明的大腦想到了用第一位來表示符號,0代表正數,1代表負數。這種表示方法最好理解,叫做原碼。但是計算機在計算的時候,為了簡化,需要把減法當做加法運算。這個很簡單,負數不就是幹這個的嗎?比如2 1 2 1 但是負數如果按照原...