GAN學習系列2 GAN的起源

2021-09-11 11:44:06 字數 3272 閱讀 3030

本文大約 5000 字,閱讀大約需要 10 分鐘

這是 gan 學習系列的第二篇文章,這篇文章將開始介紹 gan 的起源之作,鼻祖,也就是 ian goodfellow 在 2014 年發表在 iclr 的**--generative adversarial networks」,當然由於數學功底有限,所以會簡單介紹用到的數學公式和背後的基本原理,並介紹相應的優缺點。

基本原理

在[gan學習系列] 初識gan中,介紹了 gan 背後的基本思想就是兩個網路彼此博弈。生成器 g 的目標是可以學習到輸入資料的分布從而生成非常真實的,而判別器 d 的目標是可以正確辨別出真實和 g 生成的之間的差異。正如下圖所示:

上圖給出了生成對抗網路的乙個整體結構,生成器 g 和判別器 d 都是有各自的網路結構和不同的輸入,其中 g 的輸出,即生成的樣本也是 d 的輸入之一,而 d 則會為 g 提供梯度進行權重的更新。

那麼問題來了,如果 d 是乙個非常好的分類器,那麼我們是否真的可以生成非常逼真的樣本來欺騙它呢?

對抗樣本

在正式介紹 gan 的原理之前,先介紹乙個概念--對抗樣本(adversarial example),它是指經過精心計算得到的用於誤導分類器的樣本。例如下圖就是乙個例子,左邊是乙個熊貓,但是新增了少量隨機雜訊變成右圖後,分類器給出的**類別卻是長臂猿,但視覺上左右兩幅並沒有太大改變。

所以為什麼在簡單新增了雜訊後會誤導分類器呢?

這是因為影象分類器本質上是高維空間的乙個複雜的決策邊界。當然涉及到影象分類的時候,由於是高維空間而不是簡單的兩維或者三維空間,我們無法畫出這個邊界出來。但是我們可以肯定的是,訓練完成後,分類器是無法泛化到所有資料上,除非我們的訓練集包含了分類類別的所有資料,但實際上我們做不到。而做不到泛化到所有資料的分類器,其實就會過擬合訓練集的資料,這也就是我們可以利用的一點。

我們可以給新增乙個非常接近於 0 的隨機雜訊,這可以通過控制雜訊的 l2 範數來實現。l2 範數可以看做是乙個向量的長度,這裡有個訣竅就是的畫素越多,即尺寸越大,其平均 l2 範數也就越大。因此,當新增的雜訊的範數足夠低,那麼視覺上你不會覺得這張有什麼不同,正如上述右邊的一樣,看起來依然和左邊原始一模一樣;但是,在向量空間上,新增雜訊後的和原始已經有很大的距離了!

為什麼會這樣呢?

因為在 l2 範數看來,對於熊貓和長臂猿的決策邊界並沒有那麼遠,新增了非常微弱的隨機雜訊的可能就遠離了熊貓的決策邊界內,到達長臂猿的**範圍內,因此欺騙了分類器。

除了這種簡單的新增隨機雜訊,還可以通過影象變形的方式,使得新影象和原始影象視覺上一樣的情況下,讓分類器得到有很高置信度的錯誤分類結果。這種過程也被稱為對抗攻擊(adversarial attack),這種生成方式的簡單性也是給 gan 提供了解釋。

生成器和判別器

現在如果將上述說的分類器設定為二值分類器,即判斷真和假,那麼根據 ian goodfellow 的原始**的說法,它就是判別器(discriminator)。

有了判別器,那還需要有生成假樣本來欺騙判別器的網路,也就是生成器 (generator)。這兩個網路結合起來就是生成對抗網路(gan),根據原始**,它的目標如下:

兩個網路的工作原理可以如下圖所示,d 的目標就是判別真實和 g 生成的的真假,而 g 是輸入乙個隨機雜訊來生成,並努力欺騙 d 。

簡單來說,gan 的基本思想就是乙個最小最大定理,當兩個玩家(d 和 g)彼此競爭時(零和博弈),雙方都假設對方採取最優的步驟而自己也以最優的策略應對(最小最大策略),那麼結果就已經預先確定了,玩家無法改變它(納什均衡)。

因此,它們的損失函式,d 的是

g 的是

這裡根據它們的損失函式分析下,g 網路的訓練目標就是讓d(g(z)) 趨近於 1,這也是讓其 loss 變小的做法;而 d 網路的訓練目標是區分真假資料,自然是讓 d(x) 趨近於 1,而 d(g(z)) 趨近於 0。這就是兩個網路相互對抗,彼此博弈的過程了。

那麼,它們相互對抗的效果是怎樣的呢?在**中 ian goodfellow 用下圖來描述這個過程:

上圖中,黑色曲線表示輸入資料 x 的實際分布,綠色曲線表示的是 g 網路生成資料的分布,我們的目標自然是希望著兩條曲線可以相互重合,也就是兩個資料分布一致了。而藍色的曲線表示的是生成資料對應於 d 的分布。

在 a 圖中是剛開始訓練的時候,d 的分類能力還不是最好,因此有所波動,而生成資料的分布也自然和真實資料分布不同,畢竟 g 網路輸入是隨機生成的雜訊;到了 b 圖的時候,d 網路的分類能力就比較好了,可以看到對於真實資料和生成資料,它是明顯可以區分出來,也就是給出的概率是不同的;

而綠色的曲線,即 g 網路的目標是學習真實資料的分布,所以它會往藍色曲線方向移動,也就是 c 圖了,並且因為 g 和 d 是相互對抗的,當 g 網路提公升,也會影響 d 網路的分辨能力。**中,ian goodfellow 做出了證明,當假設 g 網路不變,訓練 d 網路,最優的情況會是:

也就是當生成資料的分布,也就是 d 圖的結果,這也是最終希望達到的訓練結果,這時候 g 和 d 網路也就達到乙個平衡狀態。

訓練策略和演算法實現

**給出的演算法實現過程如下所示:

這裡包含了一些訓練的技巧和方法:

首先 g 和 d 是同步訓練,但兩者訓練次數不一樣,通常是d 網路訓練 k 次後,g 訓練一次。主要原因是 gan 剛開始訓練時候會很不穩定;

d 的訓練是同時輸入真實資料和生成資料來計算 loss,而不是採用交叉熵(cross entropy)分開計算。不採用 cross entropy 的原因是這會讓 d(g(z)) 變為 0,導致沒有梯度提供給 g 更新,而現在 gan 的做法是會收斂到 0.5;

實際訓練的時候,作者是採用

分析優點

gan 在巧妙設計了目標函式後,它就擁有以下兩個優點。

缺點 當然,上述的問題在最近兩年各種 gan 變體中逐漸得到解決方法,比如對於訓練太自由的,出現了 cgan,即提供了一些條件資訊給 g 網路,比如類別標籤等資訊;對於 loss 問題,也出現如 wgan 等設計新的 loss 來解決這個問題。後續會繼續介紹不同的 gan 的變體,它們在不同方面改進原始 gan 的問題,並且也應用在多個方面。

參考文章:

配圖來自網路和** generative adversarial networks

推薦閱讀

1.機器學習入門系列(1)--機器學習概覽(上)

2.機器學習入門系列(2)--機器學習概覽(下)

3.[gan學習系列] 初識gan

GAN的學習筆記(1)

1.本文是我的新手作,主要是記錄一些學習gan 生成式對抗網路 的過程和心得體會,能夠提供一些學習的動力,各位看官能看就看。用到的是系統和相關軟體是win7 anaconda3 tensorflow gpu 1.8.0,python版本是3.5.5。至於為什麼不用ubuntu系統,emm主要是怕折騰...

GAN的調研和學習

近期集中學習了gan,下面記錄一下調研的結果,和學習的心得,疏漏的地方,敬請指正。本文將分為幾個部分進行介紹,首先是gan的由來,其次是gan的發展,最後是gan的應用。先把最近收集的資料列舉一下吧。其中首推知乎的一位博士生,講解的深入淺出,將來也是出好產品的科研人啊。令人拍案叫絕的wasserst...

docker學習系列2 儲存對容器的修改

接上篇 docker容器雖然執行起來了。但遇到了新的問題 容器內安裝的伺服器是nginx,nginx對 phpinfo 支援不好,對於thiankphp專案,簡單的說在apache伺服器下執行 http localhost 8088 home index index 能正常返回結果,而nginx返回...