之前介紹了生成式對抗網路(gan),關於gan的變種比較多,我打算將幾種常見的gan做乙個總結,也算是激勵自己學習,分享自己的一些看法和見解。
之前提到的gan是最基本的模型,我們的輸入是隨機雜訊,輸出的是對應的影象,但是我們沒法控制生成影象的型別。比如,我要生成一張數字0的,但是gan生成的卻是數字0-9的,針對這個問題,conditional generative adversarial nets被提了出來,在原有gan的基礎上,新增了類別資訊以便讓模型生成特定的。這裡的條件(conditional),就是這個額外的類別資訊。
由於在gan的生成器和判別器中都加入了額外的類別資訊,模型的目標優化函式也發生了變化。
生成器的輸入變為噪音變數
就是在gan的目標函式上新增了y這一類別變數,x變為了條件分布。
模型的結構圖如下,
gan的結構與這個類似,生成器部分和判別器部分是分開的兩個子網路,單獨進行訓練。類別資訊y是通過embedding層嵌入的。
具體的實現可以看看**:
生成器:
def build_generator(self):
model = sequential()
model.add(dense(256, input_dim=self.latent_dim))
model.add(leakyrelu(alpha=0.2))
model.add(batchnormalization(momentum=0.8))
model.add(dense(512))
model.add(leakyrelu(alpha=0.2))
model.add(batchnormalization(momentum=0.8))
model.add(dense(1024))
model.add(leakyrelu(alpha=0.2))
model.add(batchnormalization(momentum=0.8))
model.add(dense(np.prod(self.img_shape), activation='tanh'))
model.add(reshape(self.img_shape))
model.summary()
noise = input(shape=(self.latent_dim,))
label = input(shape=(1,), dtype='int32')
label_embedding = flatten()(embedding(self.num_classes, self.latent_dim)(label))
model_input = multiply([noise, label_embedding])
img = model(model_input)
return model([noise, label], img)
標籤是通過嵌入層實現的,embedding層可以將類別標籤轉換為對應的向量表示,在此生成器中,類別有10個(0-9),對應embedding中的input_dim, 輸出維度和噪音資料是相同的,之後,再利用multiply層將兩者逐項做乘積,這便是生成器的輸入。
判別器:
def build_discriminator(self):
model = sequential()
model.add(dense(512, input_dim=np.prod(self.img_shape)))
model.add(leakyrelu(alpha=0.2))
model.add(dense(512))
model.add(leakyrelu(alpha=0.2))
model.add(dropout(0.4))
model.add(dense(512))
model.add(leakyrelu(alpha=0.2))
model.add(dropout(0.4))
model.add(dense(1, activation='sigmoid'))
model.summary()
img = input(shape=self.img_shape)
label = input(shape=(1,), dtype='int32')
label_embedding = flatten()(embedding(self.num_classes, np.prod(self.img_shape))(label))
flat_img = flatten()(img)
model_input = multiply([flat_img, label_embedding])
validity = model(model_input)
return model([img, label], validity)
判別器的輸入和生成器是一樣的,輸出是對應的的類別。
訓練:訓練採用的mnist資料集,訓練時需要將資料和對應的標籤輸入模型。
生成器和判決器作為乙個整體進行訓練的時候,判別器是不訓練的,這時只訓練生成器;當判決器作為乙個單獨的模型時,判決器會得到訓練。二者的訓練是交替進行的。
具體的**可以參考github
最後跑出來的效果還是很不錯的,我在台式電腦上跑的,用的是1050ti的顯示卡,訓練速度還比較快,一共20000輪,大概10分鐘左右跑完。
這是最後的訓練效果:
可以與前一篇部落格裡面的內容進行比較,與原始的gan相比,效果要好一些,但是還是不是很清晰。一方面,mnist提供的畫素較低,另一方面,我們採用的是全連線神經網路,對於的處理效果並不是很好。
要生成更加清晰地,可以利用dcgan,這也是我接下來要做的工作。
生成對抗網路 二 cGAN
cgan conditional gan 也是最基礎的gan模型,和gan原文同時發表在nips2014上面。事實上,cgan在gan的基礎上並沒有做很大的改動,下文會主要分析一下cgan的改動。conditional generative adversarial nets 在訓練判別器d的時候,給...
半監督生成對抗網路 生成對抗網路
一 生成對抗網路相關概念 一 生成模型在概率統計理論中,生成模型是指能夠在給定某些隱含引數的條件下,隨機生成觀測資料的模型,它給觀測值和標註資料序列指定乙個聯合概率分布。在機器學習中,生成模型可以用來直接對資料建模,也可以用來建立變數間的條件概率分布。通常可以分為兩個型別,一種是可以完全表示出資料確...
生成對抗網路
我們提出乙個框架來通過對抗方式評估生成模型,我們同時訓練兩個模型 乙個生成模型g捕捉資料分布,乙個鑑別模型d估計乙個樣本來自於訓練資料而不是g的概率。g的訓練過程是最大化d犯錯的概率。這個框架與minmax兩個玩家的遊戲相對應。在任意函式g和d的空間存在乙個唯一解,g恢復訓練資料的分布,d等於1 2...