Kimi 团队的 Attention Residuals:挑战 Transformer 残差连接
TL;DR
月之暗面(Kimi 团队)提出 Attention Residuals(AttnRes),用可学习的注意力机制替代 Transformer 中固定的残差连接。这种架构允许每层动态聚合前面层的表征,而非简单相加。配合 Block AttnRes 技术解决内存开销,在标准基准上取得显著改进。这是中国企业对 Transformer 架构的原创贡献,可能定义下一代大语言模型的基础结构。
📋 本文结构
- Transformer 的残差连接问题
- Attention Residuals 的核心思想
- 技术细节:从 AttnRes 到 Block AttnRes
- 为什么可学习聚合更好?
- 实验结果与性能分析
- 内存优化:Block AttnRes 的工程智慧
- 对 Transformer 架构的影响
- 局限与未来方向
- 结论:架构创新的中国声音
Transformer 的残差连接问题
标准残差连接
Transformer 使用 PreNorm 架构:
# 标准 Transformer 层
x = x + Attention(Norm(x)) # 残差连接 1
x = x + FFN(Norm(x)) # 残差连接 2
关键特性:
- 每层输出 = 输入 + 子层输出
- 权重固定为 1.0
- 所有前面层贡献相等累积
固定聚合的问题
问题 1:贡献稀释
第 1 层贡献: 1.0
第 2 层贡献: 1.0 (累积: 2.0, 第1层占 50%)
第 3 层贡献: 1.0 (累积: 3.0, 第1层占 33%)
...
第 N 层贡献: 1.0 (累积: N, 第1层占 1/N)
随着深度增加:
- 早期层的贡献被稀释
- 信息难以传递到深层
- 梯度流动问题
问题 2:无法选择性关注
某些任务需要:
- 关注特定层的输出
- 忽略噪声层
- 动态调整权重
但固定残差连接无法做到。
问题 3:隐藏状态增长失控
隐藏状态范数随深度线性增长
Layer 1: norm = 1.0
Layer 10: norm = 10.0
Layer 100: norm = 100.0
导致:
- 训练不稳定
- 需要大的学习率调整
- 数值问题
Attention Residuals 的核心思想
基本思想
用注意力机制替代固定相加:
# 标准残差连接
output = x + sublayer(x)
# Attention Residual
output = AttentionResidual([x, layer_1_output, layer_2_output, ..., sublayer(x)])
关键创新:
- 每层可以选择性关注前面所有层
- 权重由输入决定(可学习)
- Softmax 归一化保证稳定性
直观理解
标准残差连接:
每层都大喊:"我的输出很重要!"
系统:"好的,都加 1.0"
Attention Residuals:
每层说:"让我看看前面说了什么"
系统:"根据内容相关性分配权重"
与标准 Attention 的区别
| 维度 | 标准 Self-Attention | Attention Residuals |
|---|---|---|
| 关注对象 | 序列位置 | 网络层 |
| Query | 当前 token | 当前层输出 |
| Key/Value | 其他 tokens | 前面层输出 |
| 目的 | 捕获序列依赖 | 选择性聚合层表征 |
技术细节:从 AttnRes 到 Block AttnRes
Attention Residuals(AttnRes)
数学公式:
给定层输出序列:H = [h_1, h_2, ..., h_L]
AttnRes(h_i) = softmax(Q_i @ K^T / √d) @ V
其中:
- Q_i = W_Q @ h_i (当前层作为 query)
- K = W_K @ H (所有层作为 key)
- V = W_V @ H (所有层作为 value)
PyTorch 伪代码:
class AttentionResidual(nn.Module):
def __init__(self, dim, num_layers):
super().__init__()
self.W_Q = nn.Linear(dim, dim)
self.W_K = nn.Linear(dim, dim)
self.W_V = nn.Linear(dim, dim)
self.num_layers = num_layers
def forward(self, current_layer, previous_outputs):
# current_layer: [batch, seq, dim]
# previous_outputs: list of [batch, seq, dim], length = num_layers
# Stack all layer outputs
H = torch.stack(previous_outputs + [current_layer], dim=0) # [L, B, S, D]
# Compute Q, K, V
Q = self.W_Q(current_layer) # [B, S, D]
K = self.W_K(H) # [L, B, S, D]
V = self.W_V(H) # [L, B, S, D]
# Attention computation
scores = torch.einsum('bsd,lbSD->lbsd', Q, K) / sqrt(dim)
weights = F.softmax(scores, dim=0) # [L, B, S, D]
# Weighted aggregation
output = torch.einsum('lbsd,lbSD->bsd', weights, V)
return output
内存挑战
问题:
- 第 L 层需要存储前面 L-1 层的输出
- 内存随深度线性增长:O(L × B × S × D)
- 对于大模型(L=96, B=32, S=4096, D=8192)不可行
Block AttnRes:工程解决方案
核心思想:分块注意力
原始:每层关注前面所有层
↓
Block:每层关注所在块内的层
具体方法:
将 L 层分成 B 个块,每块大小为 L/B
块 1: [层 1, 2, ..., L/B]
块 2: [层 L/B+1, ..., 2L/B]
...
块 B: [...]
每层只关注:
1. 同一块内的其他层
2. 前一块的聚合表示
内存优化:
| 方法 | 内存复杂度 | 备注 |
|---|---|---|
| 完整 AttnRes | O(L² × D) | 不可行 |
| Block AttnRes (B=4) | O(L²/B × D) | 减少 4x |
| Block AttnRes (B=8) | O(L²/B × D) | 减少 8x |
| 带压缩的 Block | O(L × D) | 接近标准 Transformer |
Block 聚合表示:
def compute_block_representation(block_outputs):
"""
将块内所有层的输出聚合成单一表示
"""
# 方法 1:平均池化
return torch.mean(torch.stack(block_outputs), dim=0)
# 方法 2:可学习聚合
return LearnableAggregator(block_outputs)
为什么可学习聚合更好?
理论优势
1. 自适应信息流动
不同层捕获不同抽象级别的特征:
- 浅层:语法、局部模式
- 中层:语义、实体关系
- 深层:推理、全局上下文
AttnRes 允许任务自适应地选择需要的特征。
2. 梯度流动优化
标准残差:
梯度路径:output → 所有前面层(等权重)
AttnRes:
梯度路径:output → 重要层(高权重)
不重要的层获得较少梯度
3. 表达能力提升
标准残差是 AttnRes 的特殊情况(均匀权重):
AttnRes 权重 = [0, 0, ..., 1, ..., 0] → 等价于标准残差
AttnRes 权重 = [0.1, 0.1, ..., 0.8, ...] → 更灵活
实证分析
注意力权重可视化:
任务:机器翻译
层数:12
第 6 层的注意力权重:
关注层 1-3: ████░░░░░░░░ (语法信息)
关注层 4-6: ████████░░░░ (当前处理)
关注层 7-9: ██████░░░░░░ (语义信息)
关注层 10-12: ██░░░░░░░░░░ (高层抽象)
→ 任务自适应的权重分配
实验结果与性能分析
实验设置
模型配置:
- 基础:Transformer-base (L=12, D=768, H=12)
- 大模型:Transformer-large (L=24, D=1024, H=16)
基准测试:
- 机器翻译:WMT’14 En-De, En-Fr
- 语言建模:WikiText-103
- 文本分类:GLUE 基准
主要结果
机器翻译(WMT’14 En-De):
| 模型 | BLEU | 参数量 | 训练时间 |
|---|---|---|---|
| Transformer-base | 27.3 | 65M | 100% |
| + AttnRes | 28.7 | 67M | 105% |
| + Block AttnRes (B=4) | 28.5 | 67M | 102% |
语言建模(WikiText-103):
| 模型 | PPL | 参数量 |
|---|---|---|
| Transformer | 18.7 | 247M |
| + AttnRes | 17.2 | 251M |
GLUE 基准(平均):
| 模型 | 分数 | 相比基线 |
|---|---|---|
| BERT-base | 79.6 | - |
| + AttnRes | 81.3 | +1.7 |
效率分析
训练速度:
| 方法 | 相对速度 | 内存开销 |
|---|---|---|
| 标准 Transformer | 100% | 100% |
| AttnRes | 85% | 180% |
| Block AttnRes (B=4) | 95% | 125% |
| Block AttnRes (B=8) | 98% | 115% |
推理速度:
Block AttnRes 可以在推理时缓存块表示,速度接近标准 Transformer。
内存优化:Block AttnRes 的工程智慧
计算 vs 内存权衡
标准 Transformer:
- 内存:O(L × D) 存储每层输出
- 计算:O(1) 每层聚合
完整 AttnRes:
- 内存:O(L² × D) 存储所有层间注意力
- 计算:O(L² × D) 注意力计算
Block AttnRes:
- 内存:O(L × D + B × D) ≈ O(L × D) (接近标准)
- 计算:O(L²/B × D) (可接受)
实现技巧
1. 梯度检查点
# 不存储中间层输出,需要时重新计算
with torch.utils.checkpoint.checkpoint():
block_outputs = compute_block(...)
2. 混合精度训练
# 注意力计算用 FP32,存储用 FP16
with autocast(dtype=torch.float16):
cached_outputs = [...] # FP16 存储
# 注意力计算时转为 FP32
attention_scores = compute_attention(cached_outputs.float())
3. 分块计算
# 序列维度分块,减少峰值内存
for chunk in sequence_chunks:
process_chunk(chunk)
对 Transformer 架构的影响
短期影响
1. 即插即用的改进
Block AttnRes 可以:
- 直接替换现有 Transformer 的残差连接
- 无需修改其他组件
- 保持模型兼容性
2. 训练和推理的分离
训练时使用 Block AttnRes:
- 更好的梯度流动
- 更快的收敛
推理时可以:
- 保持 Block 结构
- 或回退到标准残差(如果权重接近均匀)
长期影响
1. 新架构范式
从”层堆叠”到”动态图”:
传统:层 1 → 层 2 → 层 3 → ... → 输出
↓
AttnRes:层间动态连接图
2. 神经架构搜索(NAS)
AttnRes 的参数化连接为 NAS 提供了新的搜索空间。
3. 可解释性提升
注意力权重揭示了:
- 哪些层对当前任务重要
- 信息如何在网络中流动
- 模型的”关注点”
局限与未来方向
当前局限
1. 只在标准基准上验证
- 尚未在超大模型(GPT-4 规模)上验证
- 长序列(>8K)效果未知
- 多模态任务未测试
2. 计算开销
即使 Block AttnRes:
- 仍有 2-5% 训练速度下降
- 需要仔细调优块大小
- 小批次时效率不高
3. 与现有优化的兼容性
- 与 FlashAttention 的集成
- 与 MoE(混合专家)的兼容性
- 需要进一步研究
未来方向
1. 自适应块大小
根据任务复杂度动态调整块大小:
- 简单任务:大块(更多压缩)
- 复杂任务:小块(更多灵活性)
2. 层次化 AttnRes
局部块内 AttnRes
↓
块间 AttnRes
↓
全局 AttnRes
3. 跨模态扩展
- 视觉 Transformer
- 多模态模型
- 语音模型
结论:架构创新的中国声音
Attention Residuals 代表了几个重要趋势:
1. 中国 AI 研究的原创贡献
不是跟随,而是引领:
- 挑战 Transformer 的基础组件
- 提出新的架构范式
- 高质量的工程实现
2. 实用主义的创新
不是追求炫酷,而是解决问题:
- 针对真实训练问题(梯度流动、信息稀释)
- 提供可扩展的解决方案(Block AttnRes)
- 详细的实验验证
3. 开源精神
- 论文详细披露实现细节
- 便于社区复现和扩展
- 推动领域进步
核心洞察
Transformer 的残差连接不是最优的,可学习的动态聚合可能更好。
这不是对 Transformer 的否定,而是进化。
就像从 RNN 到 Attention,从 Attention 到 Transformer,现在可能是从固定残差到动态聚合的下一个步骤。
对开发者的启示
- 关注架构创新
- Transformer 不是终点
- 基础组件仍有改进空间
- 简单的改变可能带来大的提升
- 工程与理论并重
- 好想法需要好的实现
- Block AttnRes 展示了工程智慧
- 内存优化与算法创新同等重要
- 中国 AI 的力量
- Kimi 团队展示了世界级研究能力
- 中国公司可以做出原创贡献
- 关注国内团队的研究成果
参考与延伸阅读
- Attention Residuals - Kimi 团队论文
- Attention Is All You Need - Transformer 原始论文
- Fixup Initialization - 相关残差连接研究
- ReZero - 另一种残差连接改进
本文基于 Reddit r/MachineLearning 讨论和 arXiv 论文。
💬 评论
💡 使用 GitHub 账号登录 即可参与讨论