PLAS
介绍
我们提出了PLAS(Pluggable Lightweight Attention for Sparsity),这是对 MoBA 的改进。具体来说,我们采用了受 MoE 启发的结构,将 KV 划分为多个块,并引入了一个可学习的 MLP 层来自适应地选择重要块。PLAS 可以直接在训练后应用,此时只有 MLP 权重可学习,而原始模型权重保持不变。
与 NSA/MoBA 相比,我们的 PLAS 具有更高的可扩展性和可插拔性。它无需修改传统的注意力架构,也无需在训练前或训练后干扰模型权重训练。最终阶段只需对 MLP 层进行少量训练即可实现几乎无损的准确率。由于 NSA/MoBA 会更新整个模型权重,因此不可避免地会影响短文本的性能——即使它在输入长度小于 BlockSize × Top-K 时会自动切换到完全注意力机制。相比之下,我们的 PLAS 在短文本场景下可以实现与原始模型真正等同的完全注意力机制。
在训练效率方面,由于仅需更新 MLP 权重,训练成本极低。在推理性能方面,当输入长度为 128K、Block Size = 128、Top-K = 55 时,PLAS 相比 Flash Attention 3 实现了386% 的加速。
方法
训练
借鉴 NSA 和 MoBA 的方法,我们将键值对 (KV) 划分为多个块。在预填充和解码阶段,我们不再对所有键值进行注意力计算,而是动态地为每个查询 token 选择注意力得分最高的前 K 个块,从而实现高效的稀疏注意力计算。

-
Attention Gate Module: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个 MLP 层压缩每个 K 个块,生成一个具有代表性的低维表示:$K_c^T=W_{kp}K^T$,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth。通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。
-
Training Data: 得益于模型架构和训练范式的高效性,我们的方法仅使用 10 亿个 token 进行训练,便实现了近乎无损的精度。训练数据源自内部构建的包含长文本和短文本的混合语料库,从而增强了模块对不同序列长度的适应性。
-
Other: 我们观察到,最终的解码层对模型整体准确率有显著影响。因此,在训练过程中,我们将该层排除在稀疏注意力计算之外,并在推理过程中将其恢复为完全注意力。
推理优化
在稀疏注意力计算过程中,每个查询 token 可能会动态选择不同的 KV 块,导致 HBM 的内存访问模式非常不规则。简单地对每个查询 token 进行单独处理是可行的,但这会导致计算粒度过细,无法充分利用张量核,从而显著降低 GPU 的计算效率。

为了优化预填充和解码阶段的性能,我们设计了一种特殊的联合策略来适应各自的特点:
-
Prefill Token Union: 我们观察到相邻的查询标记倾向于选择相似的关键块。利用这种局部性,我们取连续 128 个查询标记选择的关键块的并集,并联合计算这些标记的稀疏注意力机制。
-
Decode Head Union: 鉴于 GQA 在现代模型中的广泛应用,我们发现同一组内的不同查询头经常选择重叠的关键块。因此,我们将同一组内所有查询头选择的关键块合并为一个统一的集合,并联合计算稀疏注意力机制。这种方式也减少了内存访问开销,并进一步提高了解码效率。
-
Top-K Selection: 传统的 Top-k 算法基于排序或直接调用 Cub 库,会带来显著的运行时开销。为了缓解这个问题,我们实现了一个基于二分查找的近似 Top-k 选择算法,该算法在保持准确率的同时显著降低了延迟,最终实现了性能的显著提升。
评估
实验
我们在 LongBenchV2 和 Ruler(上下文长度分别为 32K、64K 和 128K)上评估了全注意力和稀疏注意力的精度。
Model | Precision | |||||||
FullAttention | SparseAttention | |||||||
LongBenchV2 | Ruler | LongBenchV2 | Ruler | |||||
32K | 64K | 128K | 32K | 64K | 128K | |||
ERNIE-4.5-21B-A3B | 31.48 | 76.74 | 56.40 | 25.48 | 31.45 | 75.93 | 55.38 | 25.05 |
ERNIE-4.5-300B-A47B | 41.02 | 94.70 | 83.56 | 58.18 | 41.05 | 94.50 | 82.32 | 57.85 |
性能
我们从 InfiniteBench 中选择了一个子集 (longbook_sum_eng) 作为性能评估数据集。对于长度超过 128K 的输入,我们截断序列,保留前 64K 和后 64K 个 token。
QPS | Decode Speed (token/s) | Time to First token(s) | Time per Output Token(ms) | End-to-End Latency(s) | Mean Input Length |
Mean Output Length | ||
ERNIE-4.5-21B-A3B | FullAttention | 0.101 | 13.32 | 8.082 | 87.05 | 61.400 | 113182.32 | 627.76 |
SparseAttention | 0.150(+48%) | 18.12(+36%) | 5.466(-48%) | 66.35(-31%) | 42.157(-46%) | 113182.32 | 590.23 | |
ERNIE-4.5-300B-A47B | FullAttention | 0.066 | 5.07 | 13.812 | 206.70 | 164.704 | 113182.32 | 725.97 |
SparseAttention | 0.081(+23%) | 6.75(+33%) | 10.584(-30%) | 154.84(-34%) | 132.745(-24%) | 113182.32 | 748.25 |
使用方式
export FD_ATTENTION_BACKEND="MOBA_ATTN"
python -m fastdeploy.entrypoints.openai.api_server
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8188 \
--tensor-parallel-size 4 \
--quantization wint4 \
--enable-chunked-prefill \
--max-num-batched-tokens 8192 \
--max-model-len 131072 \
--max-num-seqs 32 \
--moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
Note: 如果启用了稀疏注意力机制,系统将自动从权重目录中的moba_mlp_weight.safetensors
文件加载 MLP 权重。如果未找到 MLP 权重文件,则将对关键表示应用均值池化
Parameter Description:
FD_ATTENTION_BACKEND="MOBA_ATTN"
启用 MOBA sparse attention.moba_encoder_top_k_left=50, moba_encoder_top_k_right=60
表示当encoder时,top-k的范围在50到60之间。moba_decoder_top_k_left=100, moba_decoder_top_k_right=120
表示当decoder时,top-k的范围在100到120之间。