Transformer
Transformer
引言:(补充)
用于机器翻译的transformer结构如下,由编码器组件和解码器组件构成。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nHN4LNRR-1624956718031)(C:\Users\sfang.fangsh\AppData\Roaming\Typora\typora-user-images\image-20200616171932170.png)]
如上图,编码器组件有一系列编码器堆砌组成,解码器组件也是由一系列的解码器堆砌组成(上图各6个)。transformer的编码器和解码器是其核心组成,理解力编码器和解码器也就理解了transformer。
1 编码器和解码器
Encoder和Decoder的结构不同,Encoder主要由两部分组成:self-Attention和两层的前馈神经网络组成。Decoder有三部分组成:self-Attention,Encoder-Decoder Attention(类似seq2seq的attention机制),两层的前馈神经网络层。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PVVsXl6c-1624956718033)(C:\Users\sfang.fangsh\AppData\Roaming\Typora\typora-user-images\image-20200616172232714.png)]
2 编码器Encoder
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ehwPkCQx-1624956718034)(C:\Users\sfang.fangsh\AppData\Roaming\Typora\typora-user-images\image-20200616172642821.png)]
encoder的输入是字嵌入+位置编码:
2.1 self attention
self attention是多头的注意力机制。假设head的个数是8,其计算方法是:
(1) 计算词嵌入向量和位置编码,相加得到输入X
  
      
       
        
        
          X 
         
        
          = 
         
        
          w 
         
        
          o 
         
        
          r 
         
        
          d 
         
        
          _ 
         
        
          e 
         
        
          m 
         
        
          b 
         
        
          e 
         
        
          d 
         
        
          d 
         
        
          i 
         
        
          n 
         
        
          g 
         
        
          + 
         
        
          p 
         
        
          o 
         
        
          s 
         
        
          i 
         
        
          t 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          a 
         
        
          l 
         
        
          _ 
         
        
          e 
         
        
          n 
         
        
          c 
         
        
          o 
         
        
          d 
         
        
          i 
         
        
          n 
         
        
          g 
         
        
       
         X = word\_embedding + positional\_encoding 
        
       
     X=word_embedding+positional_encoding
 (2) 计算每个头的Q,K,V矩阵
  
      
       
        
         
         
           Q 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           Q 
          
         
        
          , 
         
         
         
         
           K 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           K 
          
         
        
          , 
         
         
         
         
           V 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           V 
          
         
        
       
         Q_i=XW_i^Q,\qquad K_i=XW_i^K,\qquad V_i=XW_i^V 
        
       
     Qi=XWiQ,Ki=XWiK,Vi=XWiV
 (3) 用向量点乘计算当前位置Q,与其他位置K的相关度。并除以 
     
      
       
        
         
         
           d 
          
         
           m 
          
         
        
       
      
        \sqrt{d_m} 
       
      
    dm归一化
  
      
       
        
        
          s 
         
        
          c 
         
        
          o 
         
        
          r 
         
         
         
           e 
          
          
          
            i 
           
          
            j 
           
          
         
        
          = 
         
         
          
           
           
             Q 
            
           
             i 
            
           
             T 
            
           
           
           
             K 
            
           
             j 
            
           
          
          
           
           
             d 
            
           
             m 
            
           
          
         
        
       
         score_{ij}=\frac{Q_i^TK_j}{\sqrt{d_m}} 
        
       
     scoreij=dmQiTKj
 (4) 使用softmax将各自得分转为权重。
  
      
       
        
         
         
           w 
          
          
          
            i 
           
          
            j 
           
          
         
        
          = 
         
        
          s 
         
        
          o 
         
        
          f 
         
        
          t 
         
        
          m 
         
        
          a 
         
        
          x 
         
        
          ( 
         
        
          s 
         
        
          c 
         
        
          o 
         
        
          r 
         
         
         
           e 
          
          
          
            i 
           
          
            j 
           
          
         
        
          ) 
         
        
       
         w_{ij}=softmax(score_{ij}) 
        
       
     wij=softmax(scoreij)
 (5) 计算权重和
  
      
       
        
         
         
           Z 
          
         
           i 
          
         
        
          = 
         
         
         
           ∑ 
          
         
           j 
          
         
         
         
           w 
          
          
          
            i 
           
          
            j 
           
          
         
         
         
           V 
          
         
           j 
          
         
        
       
         Z_i=\sum_jw_{ij}V_j 
        
       
     Zi=j∑wijVj
 (6) 拼接得到多头注意力机制的结果
  
      
       
        
        
          Z 
         
        
          = 
         
        
          c 
         
        
          o 
         
        
          n 
         
        
          c 
         
        
          a 
         
        
          t 
         
        
          ( 
         
         
         
           Z 
          
         
           1 
          
         
        
          , 
         
         
         
           Z 
          
         
           2 
          
         
        
          , 
         
        
          . 
         
        
          . 
         
        
          . 
         
        
          , 
         
         
         
           Z 
          
         
           8 
          
         
        
          ) 
         
         
         
           W 
          
         
           0 
          
         
        
       
         Z=concat(Z_1,Z_2,...,Z_8)W_0 
        
       
     Z=concat(Z1,Z2,...,Z8)W0
 (7) 残差结构,得到self-attention的结果
  
      
       
        
        
          Z 
         
        
          = 
         
        
          L 
         
        
          a 
         
        
          y 
         
        
          e 
         
        
          r 
         
        
          N 
         
        
          o 
         
        
          r 
         
        
          m 
         
        
          ( 
         
        
          X 
         
        
          + 
         
        
          Z 
         
        
          ) 
         
        
       
         Z=LayerNorm(X+Z) 
        
       
     Z=LayerNorm(X+Z)
 2.2 前馈神经网络
主要有两层全连接组成。个人理解该层主要是为了组合多头注意力各自提取的信息。得到更好的表示。有三个操作:
(1) 全连接映射到高维,并使用激活函数。
  
      
       
        
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           1 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          = 
         
        
          g 
         
        
          e 
         
        
          l 
         
        
          u 
         
        
          ( 
         
        
          W 
         
        
          X 
         
        
          ) 
         
        
       
         FFN_1(Z)=gelu(WX) 
        
       
     FFN1(Z)=gelu(WX)
 (2) 全连接映射回原来维度,便于多层encoder叠加。
  
      
       
        
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           2 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          = 
         
        
          W 
         
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           1 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
       
         FFN_2(Z)=WFFN_1(Z) 
        
       
     FFN2(Z)=WFFN1(Z)
 (3) add&&Norm,残差结构并进行Layer Normlization。
  
      
       
        
        
          X 
         
        
          = 
         
        
          L 
         
        
          a 
         
        
          y 
         
        
          e 
         
        
          r 
         
        
          N 
         
        
          o 
         
        
          r 
         
        
          m 
         
        
          ( 
         
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           2 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          + 
         
        
          Z 
         
        
          ) 
         
        
       
         X=LayerNorm(FFN_2(Z)+Z) 
        
       
     X=LayerNorm(FFN2(Z)+Z)
 3 解码器decoder
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RdswbjJ9-1624956718035)(C:\Users\sfang.fangsh\AppData\Roaming\Typora\typora-user-images\image-20200616181311559.png)]
解码器解码时,需要用到三部分信息:
- 之前时刻的输出结果
 - 编码器的K,V矩阵,由编码器最上层encoder输出做线性变换得到。
 - 位置编码,decoder的输入向量要需要加上位置编码
 
包括三个部分:self-attention,Encoder-Decoder attention,feed forward netword。
先讨论在训练时的计算过程。seq2seq模型训练一版使用teacher forcing技巧,即在预测第i个输出时,将i时刻之前的标签值作为输入。于是,decoder的训练输入是标签y嵌入后的结果。
3.1 self-attention
(1)计算输入
  
      
       
        
        
          X 
         
        
          = 
         
        
          w 
         
        
          o 
         
        
          r 
         
        
          d 
         
        
          _ 
         
        
          e 
         
        
          m 
         
        
          b 
         
        
          e 
         
        
          d 
         
        
          d 
         
        
          i 
         
        
          n 
         
        
          g 
         
        
          + 
         
        
          p 
         
        
          o 
         
        
          s 
         
        
          i 
         
        
          t 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          a 
         
        
          l 
         
        
          _ 
         
        
          e 
         
        
          n 
         
        
          c 
         
        
          o 
         
        
          d 
         
        
          i 
         
        
          n 
         
        
          g 
         
        
       
         X = word\_embedding + positional\_encoding 
        
       
     X=word_embedding+positional_encoding
 这里是标签的word embedding,同样加上位置编码。
(2)计算Masked Multi-head attention值
与encoder的self-attention不同,在预测第i个位置输出时,i之后位置的结果被掩盖。训练时只能看到当前词之前的信息。
  
      
       
        
         
         
           Q 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           Q 
          
         
        
          , 
         
         
         
         
           K 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           K 
          
         
        
          , 
         
         
         
         
           V 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
        
          V 
         
         
         
         
           Z 
          
         
           i 
          
         
        
          = 
         
        
          A 
         
        
          t 
         
        
          t 
         
        
          e 
         
        
          n 
         
        
          t 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          ( 
         
         
         
           Q 
          
         
           i 
          
         
        
          , 
         
         
         
           K 
          
         
           i 
          
         
        
          , 
         
         
         
           V 
          
         
           i 
          
         
        
          ) 
         
        
          = 
         
        
          s 
         
        
          o 
         
        
          f 
         
        
          t 
         
        
          m 
         
        
          a 
         
        
          x 
         
        
          ( 
         
        
          m 
         
        
          a 
         
        
          s 
         
        
          k 
         
        
          ( 
         
         
          
           
           
             Q 
            
           
             i 
            
           
             T 
            
           
          
            ⋅ 
           
           
           
             K 
            
           
             i 
            
           
          
          
           
           
             d 
            
           
             m 
            
           
          
         
        
          ) 
         
        
          ) 
         
         
         
           V 
          
         
           i 
          
         
        
          , 
         
        
          = 
         
        
          1 
         
        
          , 
         
        
          2 
         
        
          , 
         
        
          . 
         
        
          . 
         
        
          . 
         
        
          , 
         
        
          8 
         
         
        
          Z 
         
        
          = 
         
        
          M 
         
        
          u 
         
        
          l 
         
        
          t 
         
        
          i 
         
        
          h 
         
        
          e 
         
        
          a 
         
        
          d 
         
        
          ( 
         
        
          Q 
         
        
          , 
         
        
          K 
         
        
          , 
         
        
          V 
         
        
          ) 
         
        
          = 
         
        
          C 
         
        
          o 
         
        
          n 
         
        
          c 
         
        
          a 
         
        
          t 
         
        
          ( 
         
         
         
           Z 
          
         
           1 
          
         
        
          , 
         
         
         
           Z 
          
         
           2 
          
         
        
          , 
         
        
          . 
         
        
          . 
         
        
          . 
         
        
          , 
         
         
         
           Z 
          
         
           8 
          
         
        
          ) 
         
         
         
           W 
          
         
           0 
          
         
         
        
          Z 
         
        
          = 
         
        
          L 
         
        
          a 
         
        
          y 
         
        
          e 
         
        
          r 
         
        
          N 
         
        
          o 
         
        
          r 
         
        
          m 
         
        
          ( 
         
        
          Z 
         
        
          + 
         
        
          X 
         
        
          ) 
         
        
       
         Q_i=XW_i^Q,\qquad K_i=XW_i^K,\qquad V_i=XW_iV\\ Z_i=Attention(Q_i,K_i,V_i)=softmax(mask(\frac{Q_i^T\cdot K_i}{\sqrt{d_m}}))V_i,=1,2,...,8\\ Z=Multihead(Q,K,V)=Concat(Z_1,Z_2,...,Z_8)W_0\\ Z=LayerNorm(Z+X) 
        
       
     Qi=XWiQ,Ki=XWiK,Vi=XWiVZi=Attention(Qi,Ki,Vi)=softmax(mask(dmQiT⋅Ki))Vi,=1,2,...,8Z=Multihead(Q,K,V)=Concat(Z1,Z2,...,Z8)W0Z=LayerNorm(Z+X)
 再求加权和前,乘以mask矩阵来达到掩盖位置i之后的信息。mask矩阵如下图,上三角为0,下三角为1。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SBs78ocP-1624956718036)(C:\Users\sfang.fangsh\AppData\Roaming\Typora\typora-user-images\image-20200616202003123.png)]
3.2 Encoder-Decoder attention
(1)和Encoder的self-attention计算方式很相似,只是K,V,矩阵来自于encoder的输出 
     
      
       
        
        
          K 
         
        
          i 
         
         
         
           e 
          
         
           n 
          
         
           d 
          
         
           e 
          
         
           c 
          
         
        
       
         , 
        
        
        
          V 
         
        
          i 
         
         
         
           e 
          
         
           n 
          
         
           d 
          
         
           e 
          
         
           c 
          
         
        
       
      
        K_i^{endec},V_i^{endec} 
       
      
    Kiendec,Viendec,使用上一层的输出Z计算得到 
     
      
       
        
        
          Q 
         
        
          i 
         
        
       
      
        Q_i 
       
      
    Qi。
  
      
       
        
         
         
           Q 
          
         
           i 
          
         
        
          = 
         
        
          X 
         
         
         
           W 
          
         
           i 
          
         
           Q 
          
         
        
          , 
         
         
         
           K 
          
         
           i 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
        
          , 
         
         
         
           V 
          
         
           i 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
         
         
         
           Z 
          
         
           i 
          
         
        
          = 
         
        
          A 
         
        
          t 
         
        
          t 
         
        
          e 
         
        
          n 
         
        
          t 
         
        
          i 
         
        
          o 
         
        
          n 
         
        
          ( 
         
         
         
           Q 
          
         
           i 
          
         
        
          , 
         
         
         
           K 
          
         
           i 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
        
          , 
         
         
         
           V 
          
         
           i 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
        
          ) 
         
        
          = 
         
        
          s 
         
        
          o 
         
        
          f 
         
        
          t 
         
        
          m 
         
        
          a 
         
        
          x 
         
        
          ( 
         
         
          
           
           
             Q 
            
           
             i 
            
           
             T 
            
           
          
            ⋅ 
           
           
           
             K 
            
           
             i 
            
            
            
              e 
             
            
              n 
             
            
              d 
             
            
              e 
             
            
              c 
             
            
           
          
          
           
           
             d 
            
           
             m 
            
           
          
         
        
          ) 
         
         
         
           V 
          
         
           i 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
        
          , 
         
        
          i 
         
        
          = 
         
        
          1 
         
        
          , 
         
        
          2 
         
        
          , 
         
        
          . 
         
        
          . 
         
        
          . 
         
        
          , 
         
        
          8 
         
         
        
          Z 
         
        
          = 
         
        
          M 
         
        
          u 
         
        
          l 
         
        
          t 
         
        
          i 
         
        
          h 
         
        
          e 
         
        
          a 
         
        
          d 
         
        
          ( 
         
        
          Q 
         
        
          , 
         
        
          K 
         
         
         
           e 
          
         
           n 
          
         
           d 
          
         
           e 
          
         
           c 
          
         
        
          , 
         
         
         
           V 
          
          
          
            e 
           
          
            n 
           
          
            d 
           
          
            e 
           
          
            c 
           
          
         
        
          ) 
         
        
          = 
         
        
          C 
         
        
          o 
         
        
          n 
         
        
          c 
         
        
          a 
         
        
          t 
         
        
          ( 
         
         
         
           Z 
          
         
           1 
          
         
        
          , 
         
         
         
           Z 
          
         
           2 
          
         
        
          , 
         
        
          . 
         
        
          . 
         
        
          . 
         
        
          , 
         
         
         
           Z 
          
         
           8 
          
         
        
          ) 
         
         
         
           W 
          
         
           0 
          
         
         
        
          Z 
         
        
          = 
         
        
          L 
         
        
          a 
         
        
          y 
         
        
          e 
         
        
          r 
         
        
          N 
         
        
          o 
         
        
          r 
         
        
          m 
         
        
          ( 
         
        
          X 
         
        
          + 
         
        
          Z 
         
        
          ) 
         
        
       
         Q_i=XW_i^Q,K_i^{endec},V_i^{endec}\\ Z_i=Attention(Q_i,K_i^{endec},V_i^{endec})=softmax(\frac{Q_i^T\cdot K_i^{endec}}{\sqrt{d_m}})V_i^{endec},i=1,2,...,8\\ Z=Multihead(Q,K{endec},V^{endec})=Concat(Z_1,Z_2,...,Z_8)W_0\\ Z=LayerNorm(X+Z) 
        
       
     Qi=XWiQ,Kiendec,ViendecZi=Attention(Qi,Kiendec,Viendec)=softmax(dmQiT⋅Kiendec)Viendec,i=1,2,...,8Z=Multihead(Q,Kendec,Vendec)=Concat(Z1,Z2,...,Z8)W0Z=LayerNorm(X+Z)
 3.3 Feed forward netword 两次前馈网络
  
      
       
        
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           1 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          = 
         
        
          g 
         
        
          e 
         
        
          l 
         
        
          u 
         
        
          ( 
         
        
          W 
         
        
          Z 
         
        
          ) 
         
         
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           2 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          = 
         
        
          W 
         
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           1 
          
         
        
          ( 
         
        
          X 
         
        
          ) 
         
         
        
          X 
         
        
          = 
         
        
          L 
         
        
          a 
         
        
          y 
         
        
          e 
         
        
          r 
         
        
          N 
         
        
          o 
         
        
          r 
         
        
          m 
         
        
          ( 
         
        
          Z 
         
        
          + 
         
        
          F 
         
        
          F 
         
         
         
           N 
          
         
           2 
          
         
        
          ( 
         
        
          Z 
         
        
          ) 
         
        
          ) 
         
        
       
         FFN_1(Z)=gelu(WZ)\\ FFN_2(Z)=WFFN_1(X)\\ X=LayerNorm(Z+FFN_2(Z)) 
        
       
     FFN1(Z)=gelu(WZ)FFN2(Z)=WFFN1(X)X=LayerNorm(Z+FFN2(Z))
4 position encoding
三角函数不是必须的,位置编码可以用参与学习的positional embedding。
用三角函数是它满足以下特性:
(1)绝对位置差异性
绝对位置不同的位置编码存在差异。
(2)相对位置的稳定性
间隔为K的任意两个位置编码的欧式距离相等。相对欧式距离只与k有关。
参考:
[1]一文看到attentionhttps://easyai.tech/ai-definition/attention/
[2]Transformer原理详解https://zhuanlan.zhihu.com/p/127774251
[3]nlp中的预训练语言模型总结(单向模型、BERT系列模型、XLNet)https://zhuanlan.zhihu.com/p/76912493