原理:
假設我們有兩個網路:乙個生g(generator),乙個判別d(discriminator)。
g是乙個生成的的網路,它接受乙個隨機的雜訊z,通過這個雜訊生成,記做g(z)。
d是乙個判別網路,判斷一張是不是「真實的」。它的輸入引數是x,x代表一張的。輸出d(x)代表x為真實的概率,如果為1,就代表100%是真實的,輸出為0,就代表不可能是真實的。
我們的目的:得到乙個生成式的模型g,可以用它來生成。
目標函式:
在這裡:
x表示真實,z表示輸入g網路的雜訊,而g(z)表示g網路生成的。
d(x)用來判斷真實的為真實的概率。
d(g(z))是d網路判斷g生成的是否為真實的概率。
g的目的:g希望自己生成的盡可能的接近真實,也就是d(g(z))得到的概率值盡可能的大。這時候是減小v(d,g),因此前面式子標記是min_g。
d的目的:d越強,d(x)應該越大,d(g(x))應該越小,因此對於式子是max_d。
所以兩個網路是對抗不過,隨著網路優化,g生成影象越來越接近真實影象,d的判別能力也越來越高。
我們可以看看lan goodfellow的描述:
(a)為一開始情況,黑線表示影象x的實際資料分布。綠線表示生成的資料分布。我們希望綠色的線能夠趨近於黑色的線。即讓生成的資料盡可能接近真實影象。藍線表示生成的資料對應於d的分布。
一開始情況如a所示,g的能力有限,跟真實影象的資料分布有一定差距,d網路識別真實影象和生成影象雖然有波動但無壓力,能夠區分出來。
訓練一段時間,到了b圖,d訓練的比較好,可以很明顯的區分生成資料,由藍線可以看出隨著黑線和綠線的差異增大,藍線明顯下降了。
訓練一段時間後,g的目標是提公升概率,因此綠線往藍線高的方向移動。
訓練到d圖,隨著g網路的提公升,g反過來影響d的分布d最終會達到0.5,g網路和d網路處於平衡狀態,無法再進一步更新了。
網路實現
演算法:使用小批量隨機梯度下降法來訓練gan。k是訓練判別網路d的更新次數。根據以往經驗,我們設k=1。
for 訓練次數:
for k次:
從雜訊先驗知識 p(z) 中抽樣m個雜訊樣本;
從資料生成分布pdata(x)中抽取小批量m個樣本。
根據以下隨機梯度上公升法更新判別器d
從雜訊先驗知識 p(z) 中抽樣m個雜訊樣本;
根據以下隨機梯度下降法更新生成器g
(1)g和d是同步訓練的,但兩者訓練次數不一樣,g訓練一次,d訓練k次。
(2)d的訓練是同時輸入生成的資料和樣本資料計算loss。(不用交叉熵是因為交叉熵會使d(g(z))變為0,導致沒有梯度,無法更新g。
在實際訓練中,文章中g網路使用了relu和sigmoid,而d網路使用了maxout和dropout。並且文章中作者實際使用-log(d(g(z))來代替log(1-d(g(z)),從而在訓練的開始使可以加大梯度資訊,但是改變後的loss將使整個gan不是乙個完美的零和博弈。
gan的訓練過程可用下圖來描述:
gan的缺點:不容易訓練
生成對抗網路 GAN
原文 generative adversarial networks 模型組成 核心公式 演算法圖示化描述 全域性最優點 pg pdata 效果與對比展望 ming maxdv d,g exp data x logd x exp x x log 1 d g z 分析 上方為 gan 網路的核心演算法...
GAN(生成對抗網路)
gan,generative adversarial network.起源於2014年,nips的一篇文章,generative adversarial net.gan,是一種二人博弈的思想,雙方利益之和是乙個常數,是固定的。你的利益多點,對方利益就少點。gan裡面,博弈雙方是 乙個叫g 生成模型 ...
生成對抗網路 GAN
機器學習中的模型一般有兩種 1.決策函式 y f x 2.條件概率分布 p y x 根據通過學習資料來獲取這兩種模型的方法,可以分為判別方法和生成方法。判別方法是由資料直接學習決策函式或條件概率分布作為 模型,即判別模型 而生成模型是由資料學習聯合概率分布 p x,y 然後由 p y x p x,y...