Attention
此处提供一张思维导图。 思维导图 (Generated by NoteBookLLM) 为什么需要注意力机制 传统的 Seq2Seq 模型(此处以 RNN 的 Encoder - Decoder 模型为例),会将输入序列压缩为一个定长的向量,解码器再从这个向量生成输出序列。但是定长的向量难以有效编码所有必要的信息,那么就成为了处理长句子的瓶颈。 注意力机制的具体运作 注意力机制将输入编码成一个向量序列(annotations)。在生成输出序列的每个词的时候,模型会软搜索输入序列中的相关位置,根据这些相关的上下文向量和之前已经生成的目标词来预测下一个目标词。 缩放点积注意力(SDPA) $$ Atten(Q,K,V)=softmax\left( \frac{QK^T}{\sqrt{ d_{k} }} \right)V \tag{1} $$ 注意力机制的核心在于计算一个上下文向量$(Atten(Q,K,V))$,这个向量是输入序列的加权和,权重反应了输入序列中每个部分对于生成序列当前输出词的重要性。 在Scaled Dot-Product Attention 中,首先计算 query 和 key 的关联性,然后将这个关联性作为value 的权重,各个权重与 value 的乘积相加得到输出。(公式 1) $\sqrt{ d_{k} }$作用是缩放注意力分数。因为当$d_{k}$很大的时候,点积$QK^T$的结果会很大,导致 Softmax 产生极度不均匀的分布,梯度会变得很小。 代码实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 import torch import torch.nn as nn class ScaledDotProductAttention(nn.Module): def __init__(self): super(ScaledDotProductAttention, self).__init__() def forward(self, query, key, value, causal_mask=None,padding_mask=None): """ Single-head Scaled Dot-Product Attention Args: query: Query tensor of shape (batch_size, seq_len_q, d_k) key: Key tensor of shape (batch_size, seq_len_k, d_k) value: Value tensor of shape (batch_size, seq_len_v, d_v) causal_mask: Optional causal mask tensor of shape (batch_size, seq_len_q, seq_len_k) padding_mask: Optional padding mask tensor of shape (batch_size, seq_len_q, seq_len_k) 1. Causal mask is used to prevent attending to future tokens in the sequence. 2. Padding mask is used to ignore padding tokens in the sequence. 3. Both masks are optional and can be None. Returns: attention_output: Attention weighted output tensor of shape (batch_size, seq_len_q, d_v) """ d_k = query.size(-1) # Hidden size of the key/query attention_scores = torch.matmul(query,key.transpose(-1,-2)) / torch.sqrt(torch.tensor(d_k,dtype=torch.float32)) if causal_mask is not None: attention_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf')) if padding_mask is not None: attention_scores = attention_scores.masked_fill(padding_mask == 0, float('-inf')) attention_weights = torch.softmax(attention_scores, dim=-1) attention_output = torch.matmul(attention_weights, value) return attention_output def test(): batch_size = 8 seq_len = 16 hidden_size = 64 query = torch.randn(batch_size,seq_len,hidden_size) key = torch.randn(batch_size,seq_len,hidden_size) value = torch.randn(batch_size,seq_len,hidden_size) sdpa = ScaledDotProductAttention() output = sdpa(query, key, value) print("Query shape:", query.shape) print("Key shape:", key.shape) print("Value shape:", value.shape) print("Output shape:", output.shape) if __name__ == "__main__": test() ...