gan最不好理解的就是loss函式的定義和訓練過程,這裡用一段**來輔助理解,就能明白到底是怎麼回事。其實gan的損失函式並沒有特殊之處,就是常用的binary_crossentropy,關鍵在於訓練過程中存在兩個神經網路和兩個損失函式。
np.random.seed(42)
tf.random.set_seed(42)
codings_size =
30generator = keras.models.sequential(
[ keras.layers.dense(
100, activation=
"selu"
, input_shape=
[codings_size]),
keras.layers.dense(
150, activation=
"selu"),
keras.layers.dense(28*
28, activation=
"sigmoid"),
keras.layers.reshape([28
,28])
])discriminator = keras.models.sequential(
[ keras.layers.flatten(input_shape=[28
,28])
, keras.layers.dense(
150, activation=
"selu"),
keras.layers.dense(
100, activation=
"selu"),
keras.layers.dense(
1, activation=
"sigmoid")]
)gan = keras.models.sequential(
[generator, discriminator]
)discriminator.
compile
(loss=
"binary_crossentropy"
, optimizer=
"rmsprop"
)discriminator.trainable =
false
gan.
compile
(loss=
"binary_crossentropy"
, optimizer=
"rmsprop"
)batch_size =
32dataset = tf.data.dataset.from_tensor_slices(x_train)
.shuffle(
1000
)dataset = dataset.batch(batch_size, drop_remainder=
true
).prefetch(
1)
這裡generator並不用compile,因為gan網路已經compile了。具體原因見下文。
訓練過程的**如下
def
train_gan
(gan, dataset, batch_size, codings_size, n_epochs=50)
: generator, discriminator = gan.layers
for epoch in
range
(n_epochs)
:print
("epoch {}/{}"
.format
(epoch +
1, n_epochs)
)# not shown in the book
for x_batch in dataset:
# phase 1 - training the discriminator
noise = tf.random.normal(shape=
[batch_size, codings_size]
) generated_images = generator(noise)
x_fake_and_real = tf.concat(
[generated_images, x_batch]
, axis=0)
y1 = tf.constant([[
0.]]
* batch_size +[[
1.]]
* batch_size)
discriminator.trainable =
true
discriminator.train_on_batch(x_fake_and_real, y1)
# phase 2 - training the generator
noise = tf.random.normal(shape=
[batch_size, codings_size]
) y2 = tf.constant([[
1.]]
* batch_size)
discriminator.trainable =
false
gan.train_on_batch(noise, y2)
plot_multiple_images(generated_images,8)
# not shown
plt.show(
)# not shown
第一階段(discriminator訓練)
# phase 1 - training the discriminator
noise = tf.random.normal(shape=
[batch_size, codings_size]
)generated_images = generator(noise)
x_fake_and_real = tf.concat(
[generated_images, x_batch]
, axis=0)
y1 = tf.constant([[
0.]]
* batch_size +[[
1.]]
* batch_size)
discriminator.trainable =
true
discriminator.train_on_batch(x_fake_and_real, y1)
這個階段首先生成數量相同的真實和假,concat在一起,即x_fake_and_real = tf.concat([generated_images, x_batch], axis=0)
。然後是label,真的label是1,假的label是0。
然後是迅速階段,首先將discrinimator設定為可訓練,discriminator.trainable = true
,然後開始階段。第乙個階段的訓練過程只訓練discriminator,discriminator.train_on_batch(x_fake_and_real, y1)
,而不是整個gan網路gan
。
第二階段(generator訓練)
# phase 2 - training the generator
noise = tf.random.normal(shape=
[batch_size, codings_size]
)y2 = tf.constant([[
1.]]
* batch_size)
discriminator.trainable =
false
gan.train_on_batch(noise, y2)
在第二階段首先生成假,但是不再生成真。把假的label全部設定為1,並把discriminator的權重凍結,即discriminator.trainable = false
。這一步很關鍵,應該這麼理解:
前面第一階段的是discriminator的訓練,使真的**值盡量接近1,假的**值盡量接近0,以此來達到優化損失函式的目的。現在將discrinimator的權重凍結,網路中輸入假,並故意把label設定為1。
注意,在整個gan網路中,從上向下的順序是先通過geneartor,再通過discriminator,即gan = keras.models.sequential([generator, discriminator])
。第二個階段將discrinimator凍結,並訓練網路gan.train_on_batch(noise, y2)
。如果generator生成的足夠真實,經過discrinimator後label會盡可能接近1。由於故意把y2的label設定為1,所以如果genrator生成的足夠真實,此時generator訓練已經達到最優狀態,不會大幅度更新權重;如果genrator生成的不夠真實,經過discriminator之後,**值會接近0,由於y2的label是1,相當於**值不準確,這時候gan網路的損失函式較大,generator會通過更新generator的權重來降低損失函式。
之後,重新回到第一階段訓練discriminator,然後第二階段訓練generator。假設整個gan網路達到理想狀態,這時候generator產生的假,經過discriminator之後,**值應該是0.5。假如這個值小於0.5,證明generator不是特別準確,在第二階段訓練過程中,generator的權重會被繼續更新。假如這個值大於0.5,證明discriminator不是特別準確,在第一階段訓練中,discriminator的權徵會被繼續更新。
簡單說,對於一張generator生成的假,discriminator會盡量把**值拉下拉,generator會盡量把**值往上扯,類似乙個拔河的過程,最後達到均衡狀態,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。
對抗神經網路(GAN)
對抗神經網路其實是兩個網路的組合,可以理解為乙個網路生成模擬資料,另乙個網路判斷生成的資料是真實的還是模擬的。生成模擬資料的網路要不斷優化自己讓判別的網路判斷不出來,判別的網路也要優化自己讓自己判斷得更準確。二者關係形成對抗,因此叫對抗神經網路。實驗證明,利用這種網路間的對抗關係所形成的網路,在無監...
GAN生成對抗神經網路原理(一)
1.基本原理 此處以生成為例進行說明 假設有2個網格,g generator 和d discriminator 功能分別是 g 生成的網格 接收乙個隨機的雜訊z,通過這個雜訊生成,記作g z d 判別網格,判別一張是不是 真實的 它的輸入引數是x,x代表一張,輸出d x 代表x真實的概率 若為1,代...
對抗神經網路的應用
接下來,我們要為你介紹一款能夠偽造人臉影象的ai neural face。neural face使用了facebook 人工智慧研究團隊開發的深度卷積神經網路 dcgan 研發團隊用由100個0到1的實數組成的1個向量z來代表每一張影象。通過計算出人類影象的分布,生成器就可以用高斯分布 gaussi...