A PyTorch Implementation of (Multi-Head) Self-Attention

2021-10-10 05:42:36

Two self-attention implementations, for learning and review.

Scaled dot-product is used as the scoring function, so the key and query have the same dimensionality; the implementation is quite simple.
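For reference, this scoring function is the scaled dot-product attention from "Attention Is All You Need" (written here in LaTeX), where d_k is the key/query dimension:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V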

from math import sqrt

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    dim_in: int
    dim_k: int
    dim_v: int

    def __init__(self, dim_in, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        # x: (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        q = self.linear_q(x)  # (batch, n, dim_k)
        k = self.linear_k(x)  # (batch, n, dim_k)
        v = self.linear_v(x)  # (batch, n, dim_v)

        dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact  # (batch, n, n)
        dist = torch.softmax(dist, dim=-1)  # (batch, n, n)

        att = torch.bmm(dist, v)  # (batch, n, dim_v)
        return att
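A minimal usage sketch for the module above (the sizes 4, 10, 32, 16 and 24 are arbitrary, chosen only for illustration; torch is already imported in the listing):

x = torch.rand(4, 10, 32)                            # (batch=4, n=10, dim_in=32)
attn = SelfAttention(dim_in=32, dim_k=16, dim_v=24)  # illustrative dimensions
y = attn(x)
print(y.shape)                                       # torch.Size([4, 10, 24])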

For simplicity, no mask is implemented here. To add one, just add -np.inf to the positions that need to be masked before the softmax: after the two tensors are matrix-multiplied, the scores at the masked positions are negative infinity, so the attention weights at those positions become 0 after the softmax.
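A minimal sketch of that masking step inside forward, assuming a hypothetical boolean mask argument of shape (batch, n, n) with True at the positions to be hidden (masked_fill with float('-inf') is used here instead of literally adding -np.inf; the effect on the softmax is the same):

# Hypothetical masked variant of the score computation in forward().
dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact  # (batch, n, n)
dist = dist.masked_fill(mask, float('-inf'))              # scores at masked positions -> -inf
dist = torch.softmax(dist, dim=-1)                        # attention weights there -> 0
att = torch.bmm(dist, v)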

The multi-head version of the self-attention above. The idea is to use one large matrix to compute the q, k and v of all heads in parallel, then use reshape and dimension swapping (permute/transpose) to move the heads into the batch dimension, so that each head goes through the same computation as single-head attention; finally the attention vectors of all heads are concatenated to give the output.

Pay attention to the trick of computing all heads in parallel here.
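A rough shape-only illustration of that trick, with arbitrary sizes (batch=2, n=5, 4 heads of size 8, and the same tensor reused as query, key and value just to keep the sketch short; the scaling factor is omitted here):

import torch

q = torch.rand(2, 5, 32)                     # (batch, n, dim_k): all heads packed together
q = q.reshape(2, 5, 4, 8).transpose(1, 2)    # (batch, nh, n, dk): heads moved beside batch
scores = torch.matmul(q, q.transpose(2, 3))  # (batch, nh, n, n): one score matrix per head
out = scores.softmax(dim=-1) @ q             # (batch, nh, n, dk): per-head attention output
out = out.transpose(1, 2).reshape(2, 5, 32)  # (batch, n, dim_k): heads concatenated back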

from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    dim_in: int      # input dimension
    dim_k: int       # key and query dimension
    dim_v: int       # value dimension
    num_heads: int   # number of heads, for each head, dim_* = dim_* // num_heads

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, \
            "dim_k and dim_v must be multiple of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)  # (batch, nh, n, n)

        att = torch.matmul(dist, v)  # (batch, nh, n, dv)
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)  # (batch, n, dim_v)
        return att

The mask is implemented in the same way as above.
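A minimal usage sketch of the multi-head module (the sizes below are arbitrary; dim_k and dim_v must be divisible by num_heads, so each head here has dk = 4 and dv = 8):

x = torch.rand(4, 10, 32)                                                  # (batch=4, n=10, dim_in=32)
mhsa = MultiHeadSelfAttention(dim_in=32, dim_k=32, dim_v=64, num_heads=8)  # illustrative dimensions
y = mhsa(x)
print(y.shape)                                                             # torch.Size([4, 10, 64])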
