Transformer 与自注意力机制
Transformer 通过自注意力机制直接建模序列中任意两个位置之间的关系,无需递归计算,是目前最强大的序列建模架构,是现代生物语言模型的基础。
- 自注意力机制捕获长程依赖
- 并行计算,训练效率高
- 是 ESM、DNABERT 等生物语言模型的核心
为什么需要全局注意力
Section titled “为什么需要全局注意力”CNN 擅长局部模式检测,RNN 擅长建模顺序依赖,但生物序列中还存在长程相互作用问题:
蛋白质结构的例子:形成稳定三维结构时,氨基酸残基之间可能存在远距离接触:
- 二硫键:半胱氨酸残基可能在序列上相距很远,但在空间上靠近
- 疏水核心:多个疏水残基需要在三维空间中聚集
- β-折叠片:来自不同位置的残基形成氢键配对
基因调控的例子:增强子(enhancer)可以在距离启动子数 kb 甚至数十 kb 的位置调控基因表达:
- 染色质环化使远距离元件物理接触
- 多个 TF 结合位点协同作用
- 这些相互作用与线性距离不成正比
这类问题的共同特点是:序列上远离的位置之间存在功能关联。CNN 的感受野有限,RNN 的梯度在长距离时会消失,而 Transformer 通过全局自注意力直接建模任意两个位置的关系。
Transformer 的核心思想
Section titled “Transformer 的核心思想”从递归到并行
Section titled “从递归到并行”RNN 顺序处理序列,第 t 步必须等待第 t-1 步完成。Transformer 的关键创新是完全并行化:
- 一次性读取整个序列
- 每个位置同时计算与所有其他位置的”注意力权重”
- 根据权重聚合信息
- 所有位置同时更新表示
想象一个会议:RNN 像依次发言(必须听完前面的人才能说),而 Transformer 像所有人同时阅读所有材料,然后同时写下自己的总结,基于对全部材料的理解。
自注意力机制
Section titled “自注意力机制”自注意力的核心问题是:对于序列中的每个位置,哪些其他位置与它最相关?
对每个位置:
- 生成 Query 向量(“我要找什么信息”)
- 所有位置生成 Key 向量(“我有什么信息”)
- 计算 Query-Key 匹配得分(相关性)
- 用得分加权聚合 Value 向量(获取相关信息)
这种机制使每个位置都能”看到”整个序列,并根据相关性选择性关注。
生物信息学应用场景
Section titled “生物信息学应用场景”| 应用 | 为什么需要 Transformer | 典型代表 |
|---|---|---|
| 蛋白质结构预测 | 氨基酸间长程接触决定 3D 结构 | AlphaFold2 (Evoformer) |
| 大规模预训练 | 并行训练效率高,可处理海量数据 | ESM、DNABERT、ProtBERT |
| 调控元件相互作用 | 增强子-启动子长距离调控 | Enformer |
| 变异效应预测 | 变异影响可能波及远处功能位点 | VEP 工具中的 Transformer 模型 |
| 单细胞多组学 | 整合不同模态的全局特征 | scGPT、scFoundation |
与 CNN/RNN 的根本区别
Section titled “与 CNN/RNN 的根本区别”| 特性 | CNN | RNN | Transformer |
|---|---|---|---|
| 信息流动 | 局部窗口滑动 | 顺序传递 | 全局直接连接 |
| 长程依赖 | 需多层堆叠 | 梯度衰减 | 直接建模 |
| 计算并行度 | 高 | 低(顺序) | 高 |
| 位置感知 | 通过卷积位置 | 通过递归顺序 | 通过位置编码 |
| 归纳偏置 | 局部性 | 时序性 | 无(全连接) |
Transformer 的强大来自于移除强归纳偏置:它不像 CNN 假设局部性,也不像 RNN 假设时序性,而是让数据自己学习哪些位置应该相互作用。代价是需要更多数据和计算来学习这些模式。
自注意力机制
Section titled “自注意力机制”给定输入序列 X = [x₁, x₂, …, x_L],每个 x_t ∈ ℝ^d。
首先通过线性变换得到 Query、Key、Value:
其中 W_Q, W_K, W_V ∈ ℝ^(d×d_k) 是可学习参数。
注意力权重计算:
其中:
- QK^T 计算所有位置对的相似度
- 除以 √d_k 进行缩放,防止梯度消失
- softmax 归一化得到注意力权重
- 与 V 加权求和得到输出
使用多组 Q、K、V 捕获不同类型的关系:
其中每个 head 的计算:
由于自注意力本身不包含位置信息,需要添加位置编码:
最终输入:
前馈网络(FFN)
Section titled “前馈网络(FFN)”每个位置独立应用前馈网络:
层归一化(Layer Norm)
Section titled “层归一化(Layer Norm)”完整 Transformer 层
Section titled “完整 Transformer 层”步骤 1:输入编码
Section titled “步骤 1:输入编码”将序列编码为嵌入表示,添加位置编码
步骤 2:多头自注意力
Section titled “步骤 2:多头自注意力”- 计算每个位置的 Query、Key、Value
- 计算注意力权重矩阵
- 加权求和得到注意力输出
- 拼接多头结果并线性变换
步骤 3:前馈网络
Section titled “步骤 3:前馈网络”对每个位置独立应用前馈网络
步骤 4:残差连接与层归一化
Section titled “步骤 4:残差连接与层归一化”应用残差连接和层归一化
步骤 5:堆叠多层
Section titled “步骤 5:堆叠多层”重复步骤 2-4 多次(通常 6-12 层)
步骤 6:输出层
Section titled “步骤 6:输出层”根据任务类型设计输出层:
- 分类:全局池化 + 线性层
- 序列标注:每个位置独立分类
- 回归:线性输出
考虑一个简化的蛋白质序列任务:使用 Transformer 预测氨基酸的二级结构。
序列片段:S = “ACDE”(长度 L = 4)
假设嵌入维度 d = 4,注意力头数 h = 2,每个头的维度 d_k = 2。
步骤 1:输入编码
Section titled “步骤 1:输入编码”假设嵌入表示(简化数值):
X = [ [1.0, 0.5, -0.3, 0.2], # A [0.3, 0.8, 0.1, -0.5], # C [-0.2, 0.4, 0.9, 0.1], # D [0.6, -0.1, 0.3, 0.7], # E]位置编码(简化):
PE = [ [0.0, 1.0, 0.0, 1.0], [0.5, 0.9, 0.5, 0.9], [0.8, 0.6, 0.8, 0.6], [1.0, 0.0, 1.0, 0.0],]添加位置编码:
X' = X + PE = [ [1.0, 1.5, -0.3, 1.2], [0.8, 1.7, 0.6, 0.4], [0.6, 1.0, 1.7, 0.7], [1.6, -0.1, 1.3, 0.7],]步骤 2:计算 Q、K、V
Section titled “步骤 2:计算 Q、K、V”假设权重矩阵(简化):
W_Q = [ [0.5, 0.3], [-0.2, 0.4], [0.1, -0.5], [0.3, 0.2],]计算 Q = X’ W_Q:
Q = [ [1.0×0.5 + 1.5×(-0.2) + (-0.3)×0.1 + 1.2×0.3, 1.0×0.3 + 1.5×0.4 + (-0.3)×(-0.5) + 1.2×0.2], ...]类似地计算 K 和 V。
步骤 3:计算注意力权重
Section titled “步骤 3:计算注意力权重”假设得到 Q、K、V:
Q = [ [0.5, 0.8], # A [0.3, 0.6], # C [0.7, 0.4], # D [0.2, 0.9], # E]
K = [ [0.4, 0.7], [0.6, 0.3], [0.2, 0.8], [0.5, 0.5],]
V = [ [0.9, 0.2], [0.3, 0.7], [0.6, 0.4], [0.1, 0.8],]计算注意力分数 QK^T:
QK^T = [ [0.5×0.4+0.8×0.7, 0.5×0.6+0.8×0.3, 0.5×0.2+0.8×0.8, 0.5×0.5+0.8×0.5], [0.3×0.4+0.6×0.7, 0.3×0.6+0.6×0.3, 0.3×0.2+0.6×0.8, 0.3×0.5+0.6×0.5], [0.7×0.4+0.4×0.7, 0.7×0.6+0.4×0.3, 0.7×0.2+0.4×0.8, 0.7×0.5+0.4×0.5], [0.2×0.4+0.9×0.7, 0.2×0.6+0.9×0.3, 0.2×0.2+0.9×0.8, 0.2×0.5+0.9×0.5],]
QK^T = [ [0.76, 0.54, 0.74, 0.65], [0.54, 0.36, 0.54, 0.45], [0.56, 0.54, 0.46, 0.55], [0.71, 0.39, 0.76, 0.55],]缩放(除以 √d_k = √2 ≈ 1.41):
Scaled = QK^T / 1.41 = [ [0.54, 0.38, 0.52, 0.46], [0.38, 0.26, 0.38, 0.32], [0.40, 0.38, 0.33, 0.39], [0.50, 0.28, 0.54, 0.39],]Softmax 归一化(以第一行为例):
完整注意力权重矩阵:
A = [ [0.27, 0.23, 0.26, 0.24], [0.28, 0.22, 0.28, 0.22], [0.26, 0.25, 0.24, 0.25], [0.28, 0.21, 0.29, 0.22],]步骤 4:加权求和
Section titled “步骤 4:加权求和”计算输出 Attention = A × V:
Output = [ [0.27×0.9+0.23×0.3+0.26×0.6+0.24×0.1, 0.27×0.2+0.23×0.7+0.26×0.4+0.24×0.8], ...]第一行:
步骤 5:多头拼接
Section titled “步骤 5:多头拼接”假设第二个头的输出为 [0.3, 0.7],拼接并线性变换得到最终输出。
这个简化的例子展示了 Transformer 如何通过自注意力机制计算每个位置对所有位置的依赖关系。
- 自注意力:O(L² × d),其中 L 是序列长度,d 是嵌入维度
- 前馈网络:O(L × d²)
- 单层总复杂度:O(L² × d + L × d²)
- 多层总复杂度:O(n_layers × (L² × d + L × d²))
与序列长度平方相关,对长序列计算成本高。
- 存储注意力矩阵:O(L²)
- 存储参数:O(n_layers × d²)
- 总空间复杂度:O(L² + n_layers × d²)
与真实工具或流程的连接
Section titled “与真实工具或流程的连接”Transformer 在生物信息学中的代表性应用:
- AlphaFold2(2021):基于 Transformer 的蛋白质结构预测
- ESM(2021):蛋白质语言模型
- DNABERT(2020):DNA 语言模型
- ProtBERT(2020):蛋白质语言模型
- Enformer(2021):基因表达预测
- HyenaDNA(2023):长序列 DNA 建模
这些工具的特点:
- 利用自注意力捕获长程依赖
- 大规模预训练后迁移学习
- 在多个下游任务上表现优异
Transformer 与 RNN/CNN 的对比
Section titled “Transformer 与 RNN/CNN 的对比”| 特性 | CNN | RNN | Transformer |
|---|---|---|---|
| 局部模式检测 | 强 | 弱 | 中 |
| 长程依赖 | 弱 | 中 | 强 |
| 并行计算 | 是 | 否 | 是 |
| 序列长度限制 | 无 | 梯度问题 | O(L²) 复杂度 |
| 参数量 | 少 | 中 | 多 |
| 可解释性 | 强 | 中 | 弱 |
算法变体与优化
Section titled “算法变体与优化”降低复杂度到 O(L × d):
其中 φ 是核函数。
限制每个位置只关注局部窗口:
其中 w 是窗口大小。
使用稀疏模式降低计算量:
- 全局注意力:少数位置关注所有位置
- 带状注意力:每个位置关注局部带状区域
- 随机注意力:随机选择注意力模式
相对位置编码
Section titled “相对位置编码”使用相对位置而非绝对位置:
其中 $r_{i-j}$ 是相对位置嵌入。
Rotary Position Embedding(RoPE)
Section titled “Rotary Position Embedding(RoPE)”通过旋转操作编码相对位置:
其中 是旋转矩阵。
Flash Attention
Section titled “Flash Attention”优化注意力计算的内存访问模式,加速训练:
- 分块计算注意力
- 减少内存读写
- 不改变计算结果
嵌入维度(d)
Section titled “嵌入维度(d)”- 小模型:128-256
- 中模型:512-768
- 大模型:1024-2048
注意力头数(h)
Section titled “注意力头数(h)”- 通常 d/h = 64
- 小模型:4-8 头
- 大模型:16-32 头
- 小模型:2-6 层
- 中模型:6-12 层
- 大模型:12-48 层
前馈网络维度
Section titled “前馈网络维度”- 通常为嵌入维度的 2-4 倍
- FFN_dim = 4 × d
Dropout
Section titled “Dropout”- 注意力 dropout:0.1-0.2
- 前馈网络 dropout:0.1-0.3
- 残差 dropout:0.1-0.2
- 预训练:1e-4 到 5e-4
- 微调:1e-5 到 1e-4
- 使用 warmup 和衰减
Warmup
Section titled “Warmup”初始阶段使用较小的学习率,逐步增加到目标值:
使用余弦衰减或线性衰减:
防止梯度爆炸:
g \leftarrow \begin\{cases\} g & \|g\| \leq \theta \\ \theta \cdot \frac\{g\}\{\|g\|\} & \|g\| > \theta \end\{cases\}混合精度训练
Section titled “混合精度训练”使用 FP16 加速训练,减少显存占用。
- Vaswani, A., et al. (2017). Attention is all you need. Advances in neural information processing systems, 30.
- Rives, A., et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. Proceedings of the National Academy of Sciences, 118(15), e2016239118.
- Ji, Y., et al. (2021). DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome. Bioinformatics, 37(15), 2112-2119.
- Avsec, Ž., et al. (2021). Effective gene expression prediction from sequence by integrating long-range interactions. Nature methods, 18(11), 1423-1433.