跳转到内容

Transformer 与自注意力机制

快速概览

Transformer 通过自注意力机制直接建模序列中任意两个位置之间的关系,无需递归计算,是目前最强大的序列建模架构,是现代生物语言模型的基础。

  • 自注意力机制捕获长程依赖
  • 并行计算,训练效率高
  • 是 ESM、DNABERT 等生物语言模型的核心
所属板块 分析方向与案例

把基础对象与算法方法重新放回真实分析任务与工作流。

阅读目标 帮助建立阅读上下文

先判断这页与你当前问题的关系,再决定是否深入展开。

建议前置 先建立相关基础对象与方法直觉

建议先建立相关基础对象与方法直觉,再进入本页。

CNN 擅长局部模式检测,RNN 擅长建模顺序依赖,但生物序列中还存在长程相互作用问题:

蛋白质结构的例子:形成稳定三维结构时,氨基酸残基之间可能存在远距离接触:

  • 二硫键:半胱氨酸残基可能在序列上相距很远,但在空间上靠近
  • 疏水核心:多个疏水残基需要在三维空间中聚集
  • β-折叠片:来自不同位置的残基形成氢键配对

基因调控的例子:增强子(enhancer)可以在距离启动子数 kb 甚至数十 kb 的位置调控基因表达:

  • 染色质环化使远距离元件物理接触
  • 多个 TF 结合位点协同作用
  • 这些相互作用与线性距离不成正比

这类问题的共同特点是:序列上远离的位置之间存在功能关联。CNN 的感受野有限,RNN 的梯度在长距离时会消失,而 Transformer 通过全局自注意力直接建模任意两个位置的关系。

RNN 顺序处理序列,第 t 步必须等待第 t-1 步完成。Transformer 的关键创新是完全并行化

  1. 一次性读取整个序列
  2. 每个位置同时计算与所有其他位置的”注意力权重”
  3. 根据权重聚合信息
  4. 所有位置同时更新表示

想象一个会议:RNN 像依次发言(必须听完前面的人才能说),而 Transformer 像所有人同时阅读所有材料,然后同时写下自己的总结,基于对全部材料的理解。

自注意力的核心问题是:对于序列中的每个位置,哪些其他位置与它最相关?

对每个位置:

  1. 生成 Query 向量(“我要找什么信息”)
  2. 所有位置生成 Key 向量(“我有什么信息”)
  3. 计算 Query-Key 匹配得分(相关性)
  4. 用得分加权聚合 Value 向量(获取相关信息)

这种机制使每个位置都能”看到”整个序列,并根据相关性选择性关注。

应用为什么需要 Transformer典型代表
蛋白质结构预测氨基酸间长程接触决定 3D 结构AlphaFold2 (Evoformer)
大规模预训练并行训练效率高,可处理海量数据ESM、DNABERT、ProtBERT
调控元件相互作用增强子-启动子长距离调控Enformer
变异效应预测变异影响可能波及远处功能位点VEP 工具中的 Transformer 模型
单细胞多组学整合不同模态的全局特征scGPT、scFoundation
特性CNNRNNTransformer
信息流动局部窗口滑动顺序传递全局直接连接
长程依赖需多层堆叠梯度衰减直接建模
计算并行度低(顺序)
位置感知通过卷积位置通过递归顺序通过位置编码
归纳偏置局部性时序性无(全连接)

Transformer 的强大来自于移除强归纳偏置:它不像 CNN 假设局部性,也不像 RNN 假设时序性,而是让数据自己学习哪些位置应该相互作用。代价是需要更多数据和计算来学习这些模式。

给定输入序列 X = [x₁, x₂, …, x_L],每个 x_t ∈ ℝ^d。

首先通过线性变换得到 Query、Key、Value:

Q=XWQ,K=XWK,V=XWVQ = X W_Q, \quad K = X W_K, \quad V = X W_V

其中 W_Q, W_K, W_V ∈ ℝ^(d×d_k) 是可学习参数。

注意力权重计算:

{Attention}(Q,K,V)={softmax}({QKT}{{dk}})V\text\{Attention\}(Q, K, V) = \text\{softmax\}\left(\frac\{QK^T\}\{\sqrt\{d_k\}\}\right)V

其中:

  • QK^T 计算所有位置对的相似度
  • 除以 √d_k 进行缩放,防止梯度消失
  • softmax 归一化得到注意力权重
  • 与 V 加权求和得到输出

使用多组 Q、K、V 捕获不同类型的关系:

{MultiHead}(Q,K,V)={Concat}({head}1,...,{head}h)WO\text\{MultiHead\}(Q, K, V) = \text\{Concat\}(\text\{head\}_1, ..., \text\{head\}_h)W^O

其中每个 head 的计算:

{head}i={Attention}(QWiQ,KWiK,VWiV)\text\{head\}_i = \text\{Attention\}(QW_i^Q, KW_i^K, VW_i^V)

由于自注意力本身不包含位置信息,需要添加位置编码:

PE{(pos,2i)}=sin({pos}{10000{2i/d}})PE_\{(pos, 2i)\} = \sin\left(\frac\{pos\}\{10000^\{2i/d\}\}\right) PE{(pos,2i+1)}=cos({pos}{10000{2i/d}})PE_\{(pos, 2i+1)\} = \cos\left(\frac\{pos\}\{10000^\{2i/d\}\}\right)

最终输入:

X=X+PEX' = X + PE

每个位置独立应用前馈网络:

{FFN}(x)=max(0,xW1+b1)W2+b2\text\{FFN\}(x) = \max(0, xW_1 + b_1)W_2 + b_2 {LayerNorm}(x)={xμ}{σ}γ+β\text\{LayerNorm\}(x) = \frac\{x - \mu\}\{\sigma\} \odot \gamma + \beta {Output}={LayerNorm}(x+{Sublayer}(x))\text\{Output\} = \text\{LayerNorm\}(x + \text\{Sublayer\}(x)) X={LayerNorm}(X+{MultiHeadAttention}(X))X' = \text\{LayerNorm\}(X + \text\{MultiHeadAttention\}(X)) X={LayerNorm}(X+{FFN}(X))X'' = \text\{LayerNorm\}(X' + \text\{FFN\}(X'))

将序列编码为嵌入表示,添加位置编码

  1. 计算每个位置的 Query、Key、Value
  2. 计算注意力权重矩阵
  3. 加权求和得到注意力输出
  4. 拼接多头结果并线性变换

对每个位置独立应用前馈网络

应用残差连接和层归一化

重复步骤 2-4 多次(通常 6-12 层)

根据任务类型设计输出层:

  • 分类:全局池化 + 线性层
  • 序列标注:每个位置独立分类
  • 回归:线性输出

考虑一个简化的蛋白质序列任务:使用 Transformer 预测氨基酸的二级结构。

序列片段:S = “ACDE”(长度 L = 4)

假设嵌入维度 d = 4,注意力头数 h = 2,每个头的维度 d_k = 2。

假设嵌入表示(简化数值):

X = [
[1.0, 0.5, -0.3, 0.2], # A
[0.3, 0.8, 0.1, -0.5], # C
[-0.2, 0.4, 0.9, 0.1], # D
[0.6, -0.1, 0.3, 0.7], # E
]

位置编码(简化):

PE = [
[0.0, 1.0, 0.0, 1.0],
[0.5, 0.9, 0.5, 0.9],
[0.8, 0.6, 0.8, 0.6],
[1.0, 0.0, 1.0, 0.0],
]

添加位置编码:

X' = X + PE = [
[1.0, 1.5, -0.3, 1.2],
[0.8, 1.7, 0.6, 0.4],
[0.6, 1.0, 1.7, 0.7],
[1.6, -0.1, 1.3, 0.7],
]

假设权重矩阵(简化):

W_Q = [
[0.5, 0.3],
[-0.2, 0.4],
[0.1, -0.5],
[0.3, 0.2],
]

计算 Q = X’ W_Q:

Q = [
[1.0×0.5 + 1.5×(-0.2) + (-0.3)×0.1 + 1.2×0.3, 1.0×0.3 + 1.5×0.4 + (-0.3)×(-0.5) + 1.2×0.2],
...
]

类似地计算 K 和 V。

假设得到 Q、K、V:

Q = [
[0.5, 0.8], # A
[0.3, 0.6], # C
[0.7, 0.4], # D
[0.2, 0.9], # E
]
K = [
[0.4, 0.7],
[0.6, 0.3],
[0.2, 0.8],
[0.5, 0.5],
]
V = [
[0.9, 0.2],
[0.3, 0.7],
[0.6, 0.4],
[0.1, 0.8],
]

计算注意力分数 QK^T:

QK^T = [
[0.5×0.4+0.8×0.7, 0.5×0.6+0.8×0.3, 0.5×0.2+0.8×0.8, 0.5×0.5+0.8×0.5],
[0.3×0.4+0.6×0.7, 0.3×0.6+0.6×0.3, 0.3×0.2+0.6×0.8, 0.3×0.5+0.6×0.5],
[0.7×0.4+0.4×0.7, 0.7×0.6+0.4×0.3, 0.7×0.2+0.4×0.8, 0.7×0.5+0.4×0.5],
[0.2×0.4+0.9×0.7, 0.2×0.6+0.9×0.3, 0.2×0.2+0.9×0.8, 0.2×0.5+0.9×0.5],
]
QK^T = [
[0.76, 0.54, 0.74, 0.65],
[0.54, 0.36, 0.54, 0.45],
[0.56, 0.54, 0.46, 0.55],
[0.71, 0.39, 0.76, 0.55],
]

缩放(除以 √d_k = √2 ≈ 1.41):

Scaled = QK^T / 1.41 = [
[0.54, 0.38, 0.52, 0.46],
[0.38, 0.26, 0.38, 0.32],
[0.40, 0.38, 0.33, 0.39],
[0.50, 0.28, 0.54, 0.39],
]

Softmax 归一化(以第一行为例):

{softmax}([0.54,0.38,0.52,0.46])={e{[0.54,0.38,0.52,0.46]}}{e{[0.54,0.38,0.52,0.46]}}\text\{softmax\}([0.54, 0.38, 0.52, 0.46]) = \frac\{e^\{[0.54, 0.38, 0.52, 0.46]\}\}\{\sum e^\{[0.54, 0.38, 0.52, 0.46]\}\} ={[1.72,1.46,1.68,1.58]}{1.72+1.46+1.68+1.58}={[1.72,1.46,1.68,1.58]}{6.44}=[0.27,0.23,0.26,0.24]= \frac\{[1.72, 1.46, 1.68, 1.58]\}\{1.72+1.46+1.68+1.58\} = \frac\{[1.72, 1.46, 1.68, 1.58]\}\{6.44\} = [0.27, 0.23, 0.26, 0.24]

完整注意力权重矩阵:

A = [
[0.27, 0.23, 0.26, 0.24],
[0.28, 0.22, 0.28, 0.22],
[0.26, 0.25, 0.24, 0.25],
[0.28, 0.21, 0.29, 0.22],
]

计算输出 Attention = A × V:

Output = [
[0.27×0.9+0.23×0.3+0.26×0.6+0.24×0.1, 0.27×0.2+0.23×0.7+0.26×0.4+0.24×0.8],
...
]

第一行:

[0.27×0.9+0.23×0.3+0.26×0.6+0.24×0.1,0.27×0.2+0.23×0.7+0.26×0.4+0.24×0.8][0.27×0.9+0.23×0.3+0.26×0.6+0.24×0.1, 0.27×0.2+0.23×0.7+0.26×0.4+0.24×0.8] =[0.243+0.069+0.156+0.024,0.054+0.161+0.104+0.192]=[0.49,0.51]= [0.243+0.069+0.156+0.024, 0.054+0.161+0.104+0.192] = [0.49, 0.51]

假设第二个头的输出为 [0.3, 0.7],拼接并线性变换得到最终输出。

这个简化的例子展示了 Transformer 如何通过自注意力机制计算每个位置对所有位置的依赖关系。

  • 自注意力:O(L² × d),其中 L 是序列长度,d 是嵌入维度
  • 前馈网络:O(L × d²)
  • 单层总复杂度:O(L² × d + L × d²)
  • 多层总复杂度:O(n_layers × (L² × d + L × d²))

与序列长度平方相关,对长序列计算成本高。

  • 存储注意力矩阵:O(L²)
  • 存储参数:O(n_layers × d²)
  • 总空间复杂度:O(L² + n_layers × d²)

Transformer 在生物信息学中的代表性应用:

  • AlphaFold2(2021):基于 Transformer 的蛋白质结构预测
  • ESM(2021):蛋白质语言模型
  • DNABERT(2020):DNA 语言模型
  • ProtBERT(2020):蛋白质语言模型
  • Enformer(2021):基因表达预测
  • HyenaDNA(2023):长序列 DNA 建模

这些工具的特点:

  • 利用自注意力捕获长程依赖
  • 大规模预训练后迁移学习
  • 在多个下游任务上表现优异
特性CNNRNNTransformer
局部模式检测
长程依赖
并行计算
序列长度限制梯度问题O(L²) 复杂度
参数量
可解释性

降低复杂度到 O(L × d):

{Attention}(Q,K,V)=ϕ(Q)(ϕ(K)TV)\text\{Attention\}(Q, K, V) = \phi(Q) (\phi(K)^T V)

其中 φ 是核函数。

限制每个位置只关注局部窗口:

{Attention}i={j=iw}{i+w}α{ij}Vj\text\{Attention\}_i = \sum_\{j=i-w\}^\{i+w\} \alpha_\{ij\} V_j

其中 w 是窗口大小。

使用稀疏模式降低计算量:

  • 全局注意力:少数位置关注所有位置
  • 带状注意力:每个位置关注局部带状区域
  • 随机注意力:随机选择注意力模式

使用相对位置而非绝对位置:

e{ij}={(xiWQ)(xjWK+r{ij})T}{{dk}}e_\{ij\} = \frac\{(x_i W_Q)(x_j W_K + r_\{i-j\})^T\}\{\sqrt\{d_k\}\}

其中 $r_{i-j}$ 是相对位置嵌入。

通过旋转操作编码相对位置:

f(x,m)=R{Θ,m}xf(x, m) = R_\{\Theta,m\} x

其中 RΘ,mR_{\Theta,m} 是旋转矩阵。

优化注意力计算的内存访问模式,加速训练:

  • 分块计算注意力
  • 减少内存读写
  • 不改变计算结果
  • 小模型:128-256
  • 中模型:512-768
  • 大模型:1024-2048
  • 通常 d/h = 64
  • 小模型:4-8 头
  • 大模型:16-32 头
  • 小模型:2-6 层
  • 中模型:6-12 层
  • 大模型:12-48 层
  • 通常为嵌入维度的 2-4 倍
  • FFN_dim = 4 × d
  • 注意力 dropout:0.1-0.2
  • 前馈网络 dropout:0.1-0.3
  • 残差 dropout:0.1-0.2
  • 预训练:1e-4 到 5e-4
  • 微调:1e-5 到 1e-4
  • 使用 warmup 和衰减

初始阶段使用较小的学习率,逐步增加到目标值:

ηt=η{max}min({t}{T{warmup}},1)\eta_t = \eta_\{\max\} \cdot \min\left(\frac\{t\}\{T_\{warmup\}\}, 1\right)

使用余弦衰减或线性衰减:

ηt=η{min}+{1}{2}(η{max}η{min})(1+cos({πt}{T{total}}))\eta_t = \eta_\{\min\} + \frac\{1\}\{2\}(\eta_\{\max\} - \eta_\{\min\})\left(1 + \cos\left(\frac\{\pi t\}\{T_\{total\}\}\right)\right)

防止梯度爆炸:

g \leftarrow \begin\{cases\} g & \|g\| \leq \theta \\ \theta \cdot \frac\{g\}\{\|g\|\} & \|g\| > \theta \end\{cases\}

使用 FP16 加速训练,减少显存占用。

  • Vaswani, A., et al. (2017). Attention is all you need. Advances in neural information processing systems, 30.
  • Rives, A., et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. Proceedings of the National Academy of Sciences, 118(15), e2016239118.
  • Ji, Y., et al. (2021). DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome. Bioinformatics, 37(15), 2112-2119.
  • Avsec, Ž., et al. (2021). Effective gene expression prediction from sequence by integrating long-range interactions. Nature methods, 18(11), 1423-1433.