近日研讀了一篇發表在iclr 2018上的文章:《learning latent permutations with gumbel- sinkhorn networks》, 其介紹了一種能夠將二維張量以可微分的形式轉變為轉置矩陣的方法。使得指派、重排等不可微分操作能夠以可微分的形式結合到神經網路當中。由此,我們便可使bp演算法學習這些操作,以實現神經網路的數字排序、拼圖等演算法。
其實我在最初使用神經網路分類時有乙個很幼稚的想法,對於最後的分類。能否設計這樣乙個損失函式:
l os
s_si
ngle
=0 & predict == y \\ 1 & predict \ != y \end
loss_s
ingl
e=\nidx:{}\nloss:{}\n"
.format
(x, idx, loss)
)> x:tensor([-
0.7181,-
0.2303,-
1.4065
,2.0853,-
0.9006
], requires_grad=
true
)> idx:
3> loss:tensor([1
], dtype=torch.int32)上面的邏輯粗略來看是沒問題的,但是,有乙個很重要的漏洞。我們呼叫了torch.max
函式,返回了**結果predict
,然後去和y
yy比較計算損失。
但是很遺憾:選取最高概率類別這個操作,即函式arg
maxi
(x)argmax_i(x)
argmax
i(x
)是不可導的。我們沒有辦法記錄這乙個操作的梯度。也就無法使用bp演算法更新網路(可以看到上方輸出中loss
並沒有記錄到梯度資訊).
既然上述方法失敗在:$argmax$
這個函式不可導上,那我們能不能進行解決呢,答案自然是可以的。簡單來說,我們可以通過以下可導函式近似argmax
函式(準確來說,是近似onehot(argmax)
函式:
s of
tmax
(xτ)
,τ→0
softmax(\frac), \tau \to 0
softma
x(τx
),τ
→0如果我們希望求得乙個最優排列,常見的,比如使用匈牙利演算法解決最優指派問題,同樣,這個選取最優指派的操作是不可導的,那麼,我們也就不能使用神經網路去學習這個問題。因此,模擬分類問題:我們能不能也使用乙個可導的操作去近似選取最優指派這個操作呢,從而使得可以被學習呢?答案是可以的
我們知道,乙個指派,實際上可以等價為乙個置換矩陣p
pp,如下所示:
[ 01
0100
001]
⏟p[1
23]⏟
x=[2
13]⏟
x\underbrace 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end }_ \underbrace 1 \\ 2 \\ 3 \end }_x =\underbrace 2 \\ 1 \\ 3 \end }_x
p⎣⎡01
010
000
1⎦⎤
x
⎣⎡1
23⎦
⎤
=x⎣⎡
213
⎦⎤
所以,我們能否可微地去近似置換矩陣p
pp呢,從而通過學習p
pp去學習指派這個操作呢?答案是可以,方法就是sinkhorn operator
。
給定乙個方陣x
xx. 我們可以通過以下變幻將其變為雙線性矩陣。(所謂雙線性矩陣,就是其每一行每一列的和都為1).
s 0(
x)=exp(
x)sl
(x)=
tc(t
r(sl
−1(x
)))s
(x)=
liml→
∞sl(
x)\begin s^(x) &=\exp (x) \\ s^(x) &=\mathcal_\left(\mathcal_\left(s^(x)\right)\right) \\ s(x) &=\lim _ s^(x) \end
s0(x)s
l(x)
s(x)
=exp(x)
=tc
(tr
(sl−
1(x)
))=l
→∞limsl
(x)
當然,對於指派問題,僅僅是雙線性矩陣還是不夠的,因為我們要保證$s(x)$
中的元素是非0即1的。而這個限制,我們可以通過增加乙個超引數$\tau$
實現:
m (x
)=limτ→
0+s(
x/τ)
m(x)=\lim _} s(x / \tau)
m(x)=τ
→0+lims
(x/τ
)其中,m(x
)=arg
maxp∈
pn⟨p
,x⟩f
m(x)=\underset_}\langle p, x\rangle_
m(x)=p
∈pn
argmax⟨
p,x⟩
f為對應收益矩陣為
x
xx的最優置換矩陣,
⟨a,
b⟩f=
trace(
a⊤b)
\langle a, b\rangle_=\operatorname\left(a^ b\right)
⟨a,b⟩f
=tr
ace(
a⊤b)
`這樣,我們通過神經網路去將原始資料編碼為矩陣x
xx, 再通過可微操作limτ
→0+s
(x/τ
)\lim _} s(x / \tau)
limτ→0
+s(
x/τ)
近似x
xx對應的指派m(x
)m(x)
m(x)
。最後就可以實現梯度更新從而訓練網路了。
下面是乙個實現拼圖的示意圖:
個人使用pytorch復現了一遍原文給出的數字排序實驗:
排列問題處理小技巧
乙個例子說明問題 反幻方我國古籍很早就記載著 2 9 4 7 5 3 6 1 8 這是乙個三階幻方。每行每列以及對角線上的數字相加都相等。下面考慮乙個相反的問題。可不可以用 1 9 的數字填入九宮格。使得 每行每列每個對角線上的數字和都互不相等呢?這應該能做到。比如 9 1 2 8 4 3 7 5 ...
列舉排列 生成可重集的排列
輸入陣列p,按照字典序輸出所有的p中元素的全排列 由於c c 語言中的函式在接受陣列引數的時候無法得知陣列的元素個數所以需要傳乙個已經填好的位置個數,或者當前需要確定元素位置cur,還需要傳輸原始陣列p,共四個引數 生成的排列有重複的所以需要統計a 0 a cur 1 中p i 出現的次數c1,以及...
生成可重集的排列
發現自己菜的連可重集的排列都求不出來。於是今天看懂 來分析一波。首先想能不能用非可重集做。顯然是錯的,因為同一種排列會被計算多次。所以,應該將同一類的數提取出來,統一填充。體現在演算法中,就是先排序,然後找出每一類的數,在個數不超的情況下,隨便填。include include using name...