跳转到内容

循环神经网络(RNN)与 LSTM

快速概览

循环神经网络通过隐状态传递信息,能够建模序列中的长程依赖关系;LSTM 通过门控机制解决了标准 RNN 的梯度消失问题,是序列建模的经典架构。

  • 隐状态捕获序列上下文信息
  • LSTM 的门控机制实现选择性记忆和遗忘
  • 适合处理变长序列和时序依赖
所属板块 分析方向与案例

把基础对象与算法方法重新放回真实分析任务与工作流。

阅读目标 帮助建立阅读上下文

先判断这页与你当前问题的关系,再决定是否深入展开。

建议前置 先建立相关基础对象与方法直觉

建议先建立相关基础对象与方法直觉,再进入本页。

与 CNN 的局部扫描不同,许多生物序列问题需要考虑上下文才能做出准确判断:

基因预测的例子:判断一个位置是否属于编码区,需要:

  • 上游的阅读框信息(当前处于第几个密码子位置)
  • 下游的终止密码子信号
  • 序列整体的 GC 含量和碱基使用偏好

可变剪接的例子:识别剪接位点需要理解:

  • 5’ 剪接位点的 AG|GURAGU 共识序列
  • 3’ 剪接位点的 YYYYYYNC|AG 共识序列
  • 分支点(branch point)的位置
  • 这些信号之间的相对距离和相互作用

这类问题的共同特点是:每个位置的解释依赖于序列的上下文信息。RNN 正是为解决这类问题而设计的架构。

递归处理:RNN 维护一个隐状态(hidden state),在处理序列时不断更新:

  1. 读取当前输入(如第 t 个碱基)
  2. 结合前一时刻的隐状态
  3. 生成当前时刻的输出和新的隐状态
  4. 将新隐状态传递给下一时刻

想象你在阅读一条 DNA 序列,一边读一边在脑海中”记住”之前看到的内容。隐状态就是你的”记忆”,而 RNN 的递归机制决定了如何更新这个记忆。

标准 RNN 的局限与 LSTM 的解决方案

Section titled “标准 RNN 的局限与 LSTM 的解决方案”

标准 RNN 在处理长序列时会遇到梯度消失问题:随着时间步的增加,隐状态对早期输入的”记忆”逐渐消失。这在生物序列中是个严重问题,因为:

  • 基因可能长达数千个碱基
  • 远距离调控元件可能位于数 kb 之外
  • 剪接信号需要记住分支点的位置

LSTM(Long Short-Term Memory)通过三个门控机制解决梯度消失:

门控功能类比
遗忘门决定从细胞状态中丢弃什么信息”忘记不重要的旧信息”
输入门决定存储什么新信息”记录重要的新信息”
输出门决定输出什么信息”基于当前记忆做决定”

这些门控使 LSTM 能够选择性记忆,在数千个时间步后仍能保留关键信息。

应用为什么需要 RNN/LSTM典型输入/输出
基因预测编码区识别需要阅读框和上下文DNA 序列 → 每个位置的标签(编码/内含子/UTR)
可变剪接预测剪接位点识别需要理解多种信号的组合内含子-外显子边界 → 剪接位点概率
蛋白质二级结构结构状态依赖于局部氨基酸环境氨基酸序列 → 每个残基的结构状态(螺旋/折叠/卷曲)
单细胞轨迹分析细胞分化是连续动态过程基因表达时间序列 → 细胞状态/分化方向
序列生成生成合理的生物序列需要维持序列约束起始序列 → 生成的后续序列

给定序列 X=[x1,x2,...,xT]X = [x_1, x_2, ..., x_T],每个 xtRKx_t \in \mathbb{R}^{K}

隐状态更新公式:

ht=σ(W{hh}h{t1}+W{xh}xt+bh)h_t = \sigma(W_\{hh\} h_\{t-1\} + W_\{xh\} x_t + b_h)

输出公式:

yt=σ(W{hy}ht+by)y_t = \sigma'(W_\{hy\} h_t + b_y)

其中:

  • htRHh_t \in \mathbb{R}^{H} 是时间步 t 的隐状态
  • WhhRH×HW_{hh} \in \mathbb{R}^{H\times H} 是隐状态到隐状态的权重
  • WxhRH×KW_{xh} \in \mathbb{R}^{H\times K} 是输入到隐状态的权重
  • WhyRO×HW_{hy} \in \mathbb{R}^{O\times H} 是隐状态到输出的权重
  • σ 和 σ’ 是激活函数

LSTM 通过三个门控机制控制信息流:

ft=σ(Wf[h{t1},xt]+bf)f_t = \sigma(W_f \cdot [h_\{t-1\}, x_t] + b_f)

决定从细胞状态中遗忘多少信息。

it=σ(Wi[h{t1},xt]+bi)i_t = \sigma(W_i \cdot [h_\{t-1\}, x_t] + b_i) {~C}t=tanh(WC[h{t1},xt]+bC)\tilde\{C\}_t = \tanh(W_C \cdot [h_\{t-1\}, x_t] + b_C)

决定向细胞状态中添加多少新信息。

Ct=ftC{t1}+it{~C}tC_t = f_t \odot C_\{t-1\} + i_t \odot \tilde\{C\}_t

其中 ⊙ 表示逐元素乘法。

ot=σ(Wo[h{t1},xt]+bo)o_t = \sigma(W_o \cdot [h_\{t-1\}, x_t] + b_o) ht=ottanh(Ct)h_t = o_t \odot \tanh(C_t)

决定输出多少信息作为新的隐状态。

GRU 是 LSTM 的简化版本,只有两个门:

zt=σ(Wz[h{t1},xt]+bz)z_t = \sigma(W_z \cdot [h_\{t-1\}, x_t] + b_z) rt=σ(Wr[h{t1},xt]+br)r_t = \sigma(W_r \cdot [h_\{t-1\}, x_t] + b_r) {~h}t=tanh(Wh[rth{t1},xt]+bh)\tilde\{h\}_t = \tanh(W_h \cdot [r_t \odot h_\{t-1\}, x_t] + b_h) ht=(1zt)h{t1}+zt{~h}th_t = (1 - z_t) \odot h_\{t-1\} + z_t \odot \tilde\{h\}_t

标准 RNN 的核心是隐状态的递归更新:在每个时间步,当前输入与前一时刻的隐状态通过可学习权重矩阵结合,经激活函数(通常为 tanh 或 ReLU)产生新的隐状态。这种结构使 RNN 能够捕获序列中的上下文信息,但也带来了梯度消失/爆炸问题:当序列较长时,反向传播通过时间展开后,梯度会指数级衰减或增长,导致模型难以学习长程依赖关系。

初始化隐状态 h₀(通常为零向量或随机初始化)

对每个时间步 t = 1, 2, …, T:

  1. 计算 ht=σ(Whhht1+Wxhxt+bh)h_t = \sigma(W_{hh} h_{t-1} + W_{xh} x_t + b_h)
  2. 计算 yt=σ(Whyht+by)y_t = \sigma'(W_{hy} h_t + b_y)

根据任务类型计算损失(如交叉熵、MSE)

通过时间反向传播(BPTT)更新所有参数

LSTM 通过引入**细胞状态(cell state)**和三个门控机制解决标准 RNN 的梯度消失问题。遗忘门决定丢弃哪些旧信息,输入门决定存储哪些新信息,输出门决定输出哪些信息。细胞状态作为信息传输的主干道,可以几乎无损地跨多个时间步传递信息,从而使模型能够学习长程依赖。GRU 是 LSTM 的简化版本,合并了遗忘门和输入门为单一更新门,参数更少、计算更快,在许多任务上表现相当。

初始化隐状态 h0h_0 和细胞状态 C0C_0

对每个时间步 t = 1, 2, …, T:

  1. 计算遗忘门 ftf_t
  2. 计算输入门 iti_t 和候选细胞状态 C~t\tilde{C}_t
  3. 更新细胞状态 CtC_t
  4. 计算输出门 oto_t
  5. 更新隐状态 hth_t
  6. 计算输出 yty_t

根据任务类型计算损失

通过 BPTT 更新所有参数

考虑一个简化的 DNA 序列分类任务:使用 LSTM 预测序列是否为启动子。

序列:S = “TATAAA”(长度 T = 6)

假设使用 one-hot 编码,隐藏层维度 H = 4。

h0=[0,0,0,0]h_0 = [0, 0, 0, 0] C0=[0,0,0,0]C_0 = [0, 0, 0, 0]

步骤 2:时间步 t=1(输入 x₁ = T)

Section titled “步骤 2:时间步 t=1(输入 x₁ = T)”

假设权重和偏置如下(简化数值):

遗忘门:

f1=σ(Wf[h0,x1]+bf)=σ([0.3,0.2,0.1,0.1])=[0.57,0.45,0.52,0.48]f_1 = \sigma(W_f \cdot [h_0, x_1] + b_f) = \sigma([0.3, -0.2, 0.1, -0.1]) = [0.57, 0.45, 0.52, 0.48]

输入门:

i1=σ(Wi[h0,x1]+bi)=σ([0.5,0.3,0.1,0.2])=[0.62,0.57,0.48,0.55]i_1 = \sigma(W_i \cdot [h_0, x_1] + b_i) = \sigma([0.5, 0.3, -0.1, 0.2]) = [0.62, 0.57, 0.48, 0.55]

候选细胞状态:

{~C}1=tanh(WC[h0,x1]+bC)=tanh([0.8,0.5,0.6,0.3])=[0.66,0.46,0.54,0.29]\tilde\{C\}_1 = \tanh(W_C \cdot [h_0, x_1] + b_C) = \tanh([0.8, -0.5, 0.6, -0.3]) = [0.66, -0.46, 0.54, -0.29]

细胞状态更新:

C1=f1C0+i1{~C}1=[0,0,0,0]+[0.41,0.26,0.26,0.16]=[0.41,0.26,0.26,0.16]C_1 = f_1 \odot C_0 + i_1 \odot \tilde\{C\}_1 = [0, 0, 0, 0] + [0.41, -0.26, 0.26, -0.16] = [0.41, -0.26, 0.26, -0.16]

输出门:

o1=σ(Wo[h0,x1]+bo)=σ([0.4,0.2,0.3,0.1])=[0.60,0.55,0.43,0.52]o_1 = \sigma(W_o \cdot [h_0, x_1] + b_o) = \sigma([0.4, 0.2, -0.3, 0.1]) = [0.60, 0.55, 0.43, 0.52]

隐状态更新:

h1=o1tanh(C1)=[0.60,0.55,0.43,0.52]tanh([0.41,0.26,0.26,0.16])h_1 = o_1 \odot \tanh(C_1) = [0.60, 0.55, 0.43, 0.52] \odot \tanh([0.41, -0.26, 0.26, -0.16]) h1=[0.60,0.55,0.43,0.52][0.39,0.25,0.25,0.16]=[0.23,0.14,0.11,0.08]h_1 = [0.60, 0.55, 0.43, 0.52] \odot [0.39, -0.25, 0.25, -0.16] = [0.23, -0.14, 0.11, -0.08]

类似地处理 t=2, 3, 4, 5, 6,每个时间步更新隐状态和细胞状态。

使用最后一个隐状态 h6h_6 进行预测:

y=σ(W{hy}h6+by)y = \sigma(W_\{hy\} \cdot h_6 + b_y)

假设 h6=[0.8,0.6,0.3,0.4]h_6 = [0.8, 0.6, -0.3, 0.4]Why=[0.5,0.3,0.2,0.4]W_{hy} = [0.5, 0.3, -0.2, 0.4]by=0.3b_y = -0.3

y=σ(0.5×0.8+0.3×0.60.2×(0.3)+0.4×0.40.3)y = \sigma(0.5 \times 0.8 + 0.3 \times 0.6 - 0.2 \times (-0.3) + 0.4 \times 0.4 - 0.3) y=σ(0.4+0.18+0.06+0.160.3)=σ(0.5)0.62y = \sigma(0.4 + 0.18 + 0.06 + 0.16 - 0.3) = \sigma(0.5) ≈ 0.62

由于 y > 0.5,预测该序列为启动子。

这个简化的例子展示了 LSTM 如何通过门控机制逐步更新细胞状态和隐状态,捕获序列中的上下文信息。

  • 单步 RNN:O(H² + HK + HO)
  • 单步 LSTM:O(4H² + 4HK + HO)(4 个门)
  • 完整序列:O(T × (H² + HK + HO)) 对于 RNN
  • 完整序列:O(T × (4H² + 4HK + HO)) 对于 LSTM

与序列长度 T 线性相关,但每个时间步的计算成本高于 CNN。

  • 存储参数:O(H² + HK + HO) 对于 RNN
  • 存储参数:O(4H² + 4HK + HO) 对于 LSTM
  • 存储 BPTT 梯度:O(T × H)

RNN/LSTM 在生物信息学中的代表性应用:

  • DeepGene(2015):基于 LSTM 的基因预测
  • SpliceAI(2019):基于深度学习的剪接位点预测
  • DeepPromoter(2018):启动子识别
  • Basset(2016):结合 CNN 和 RNN 的染色质可及性预测
  • scVelo(2020):单细胞 RNA 速度分析(使用类似 RNN 的动力学模型)

这些工具的特点:

  • 利用长程依赖信息
  • 适合变长序列
  • 可以结合 CNN(如 DanQ)
维度 标准 RNN LSTM
门控机制 遗忘门、输入门、输出门
梯度消失 严重 缓解
长程依赖 难以学习 能够学习
参数量 多(约 4 倍)
计算速度 较慢
训练难度 较难 较容易

同时处理正向和反向序列:

{h}t={RNN}(xt,{h}{t1})\overrightarrow\{h\}_t = \text\{RNN\}(x_t, \overrightarrow\{h\}_\{t-1\}) {h}t={RNN}(xt,{h}{t+1})\overleftarrow\{h\}_t = \text\{RNN\}(x_t, \overleftarrow\{h\}_\{t+1\}) ht=[{h}t,{h}t]h_t = [\overrightarrow\{h\}_t, \overleftarrow\{h\}_t]

堆叠多个 RNN 层,学习更复杂的特征:

ht{(l)}={RNN}(ht{(l1)},h{t1}{(l)})h_t^\{(l)\} = \text\{RNN\}(h_t^\{(l-1)\}, h_\{t-1\}^\{(l)\})

结合注意力机制,动态加权不同位置的隐状态:

ct={i=1}Tα{ti}hic_t = \sum_\{i=1\}^T \alpha_\{ti\} h_i

其中 αti\alpha_{ti} 是注意力权重。

防止梯度爆炸:

g \leftarrow \begin\{cases\} g & \|g\| \leq \theta \\ \theta \cdot \frac\{g\}\{\|g\|\} & \|g\| > \theta \end\{cases\}

加速训练,稳定梯度:

{~h}t={htμ}{σ}γ+β\tilde\{h\}_t = \frac\{h_t - \mu\}\{\sigma\} \odot \gamma + \beta
  • 小数据集:32-128
  • 大数据集:128-512
  • 1-2 层通常足够
  • 复杂任务可尝试 2-4 层
  • RNN 输出层:0.2-0.5
  • LSTM/GRU 门控:通常不使用 dropout
  • Adam 优化器:1e-3 到 1e-4
  • SGD 优化器:1e-2 到 1e-3
  • 过长序列可能导致梯度问题
  • 可考虑截断或分段处理
  • Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
  • Cho, K., et al. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.
  • Jaganathan, K., et al. (2019). Predicting splicing from primary sequence with deep learning. Nature, 571(7763), 115-119.
  • Berg, J. A., et al. (2019). DeepPromoter: Promoter prediction using deep learning. BMC bioinformatics, 20(1), 1-13.