rCM: Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency
Authors: Kaiwen Zheng, Yuji Wang, Qianli Ma, Huayu Chen, Jintao Zhang, Yogesh Balaji, Jianfei Chen, Ming-Yu Liu, Jun Zhu, Qinsheng Zhang Affiliations: Tsinghua University, NVIDIA arXiv: 2510.08431 Venue: ICLR 2026
1. Motivation (研究动机)
1.1 问题背景
连续时间一致性模型 (sCM, MeanFlow) 在理论上优雅, 但在大规模应用中面临三大挑战:
- 质量瓶颈: sCM 的 forward divergence (mode-covering) 特性导致细节生成退化 (如文字渲染失败、视频时序不一致)
- 误差累积: JVP 自反馈信号在 BF16 精度下数值脆弱, 随时间 增大误差被放大; 当 时, teacher 监督消失, 学习动态被 JVP 主导
- 工程限制: PyTorch 原生
torch.func.jvp不兼容 FlashAttention-2、FSDP、Context Parallelism, 无法训练 10B+ 模型
1.2 现有方法的局限
论文未详细说明。
2. Idea (核心思想)
2.1 核心贡献
| 贡献 | 描述 |
|---|---|
| rCM 框架 | 联合 forward divergence (sCM) + reverse divergence (DMD) 的蒸馏方法, 即可平衡质量与多样性 |
| FlashAttention-2 JVP Kernel | 基于 Triton 实现的 JVP 核, 支持 self/cross attention, 兼容 FSDP/CP |
| 稳定时间导数计算 | Semi-continuous time (有限差分) 和 High-precision time (FP32 时间嵌入) 两种策略 |
| 大规模验证 | 首次将连续时间一致性蒸馏扩展到 Cosmos-Predict2 (14B) T2I 和 Wan2.1 (14B) T2V, 1 |
2.2 rCM: 联合目标
核心洞察: sCM (forward divergence) 提供 mode-covering (高多样性但低质量), DMD (reverse divergence) 提供 mode-seeking (高质量但低多样性). rCM 将 score distillation 作为 long-skip regularizer, 直接对大时间步提供 teacher 监督, 弥补 sCM 在大 时 JVP 主导导致的误差累积.
Figure 2 解读: 蒸馏方法的高层对比表。Consistency Model 仅有 forward divergence (高多样性、低质量); Score Distillation 仅有 reverse divergence (高质量、mode collapse); GAN 两者兼有但难调; rCM 同时具有两种 divergence, 既易调又高质高多样性。
Figure 4 解读: rCM 的核心架构图。左侧: sCM 的 forward consistency 目标, 从 出发加噪, 通过 JVP 计算 tangent, 执行瞬时自洽性约束。误差从小 向大 传播累积。右侧: DMD 的 reverse divergence, 从 出发用 student 多步 rollout 生成样本 , 然后用 fake score network 与 teacher score 的差值提供梯度。两条路径互补: forward 保多样性, reverse 修质量。
3. Method (方法)
3.1 背景: sCM 与 Score Distillation
Diffusion Model 采样 PF-ODE:
一致性函数 , 满足边界条件 , 参数化为:
sCM 使用 TrigFlow 调度 , preconditioning .
sCM Loss (forward divergence, mode-covering):
其中 , , .
Score Distillation (reverse divergence, mode-seeking) - DMD Loss:
3.2 适配任意 Noise Schedule
Teacher 通常用 rectified flow (), 而 sCM 用 TrigFlow. 通过信噪比匹配构造 wrapped teacher:
其中 是从 TrigFlow 时间到原始时间的映射, 通过匹配 求得.
3.3 Rollout 策略
Student 生成样本用于 DMD loss 和 fake score 训练. 交替执行反向去噪和前向加噪:
- 随机选择步数
- 仅对最后一步反向传播 DMD loss
- 随机采样 , 设 确保时间步单调递减
3.4 稳定时间导数计算
JVP 的时间导数 中, 因三角函数时间嵌入的振荡性而不稳定.
策略 1: Semi-Continuous Time (适用于 2B 以下 T2I):
策略 2: High-Precision Time (适用于 10B+ 和视频):
- 使用
torch.amp.autocast对所有时间嵌入层强制 FP32 精度 - 保持原生连续时间导数的完整 JVP 计算
3.5 工程基础设施
3.5.1 FlashAttention-2 JVP Kernel
标准 attention: , ,
JVP 需要计算 tangent :
关键: O 和 tO 可以在同一个 streaming loop 中计算, 类似 FlashAttention-2 的 block-wise tiling, 用 Triton 实现.
3.5.2 FSDP 兼容性
定义 JVP 基类继承 torch.nn.Module, 每层实现 _forward 和 _forward_jvp 两个接口:
class JVP(torch.nn.Module):
def forward(self, *args, **kwargs):
withT = kwargs.pop("withT", False)
if withT:
return self._forward_jvp(*args, **kwargs)
else:
return self._forward(*args, **kwargs)withT=True时接收 primal+tangent, 返回 primal+tangent (包装为TensorWithT)- FSDP 粒度与 layer 边界对齐即可兼容
- Attention block 用自定义 FlashAttention-2 JVP kernel 替换, 其余模块用
torch.func.jvp
3.5.3 Context Parallelism (CP)
Ulysses 策略: 输入张量 [B, H, L, C] 沿序列维度 L 切分到 P 个 GPU, 通过 all-to-all 重分布 QKV. JVP 自然扩展: tangent 的 QKV 同样参与分布, 使用 FlashAttention-2 JVP kernel 处理本地 attention.
3.6 训练伪代码与代码映射
3.6.1 Algorithm 1: rCM 完整训练算法
# rCM training loop
theta = copy_weights(theta_teacher)
theta_fake = copy_weights(theta_teacher)
for step in range(1, num_iterations + 1):
if step <= tangent_warmup or step % student_update_freq == 0:
x0 = sample_batch(dataset)
eps = torch.randn_like(x0)
t = sample_time(p_g)
x_t = torch.cos(t) * x0 + torch.sin(t) * eps
tangent = jvp_tangent(F_theta_minus, F_teacher, x_t, t)
warmup_ratio = min(1.0, step / tangent_warmup)
guidance = build_rcm_guidance(
student_prev=F_theta_minus(x_t, t),
teacher=F_teacher(x_t, t),
tangent=tangent,
x_t=x_t,
t=t,
warmup_ratio=warmup_ratio,
)
loss = consistency_loss(F_theta(x_t, t), F_theta_minus(x_t, t), guidance)
if step > tangent_warmup:
rollout_steps = sample_uniform(1, max_rollout_steps)
x0_student = backward_rollout(student=F_theta, num_steps=rollout_steps)
loss += dmd_loss(x0_student, fake_score=F_fake, teacher_score=F_teacher)
optimize(loss, params=theta)
else:
rollout_steps = sample_uniform(1, max_rollout_steps)
x0_student = backward_rollout(student=F_theta_minus, num_steps=rollout_steps)
eps = torch.randn_like(x0_student)
t = sample_time(p_d)
x_t = torch.cos(t) * x0_student + torch.sin(t) * eps
target = torch.cos(t) * eps - torch.sin(t) * x0_student
loss_fake = flow_matching_loss(F_fake(x_t, t), target)
optimize(loss_fake, params=theta_fake)3.6.2 代码映射表
| 论文概念 | 代码路径 (NVlabs/rcm) | 说明 |
|---|---|---|
| rCM 训练主循环 | rcm/training/ | Generator + Critic 交替训练 |
| sCM Loss (Eq. 4) | rcm/training/ | tangent normalization, JVP 计算 |
| DMD Loss (Eq. 6) | rcm/training/ | student rollout + fake score gradient |
| FlashAttention-2 JVP Kernel | rcm/utils/ | Triton 实现, 支持 self/cross attention |
| JVP 基类 | rcm/networks/ | JVP(torch.nn.Module), _forward/_forward_jvp |
| Network Restructuring | rcm/networks/ | RMSNorm 等层的 JVP 适配 |
| Wrapped Teacher (Eq. 3) | rcm/models/ | TrigFlow-consistent teacher wrapper |
| T2V Inference | rcm/inference/wan2pt1_t2v_rcm_infer.py | 支持 1-4 步采样 |
| T2I Inference | rcm/inference/ | Cosmos-Predict2 推理脚本 |
| Training Configs | rcm/configs/ | Hydra 配置文件 |
3.6.3 推理示例
PYTHONPATH=. python rcm/inference/wan2pt1_t2v_rcm_infer.py \
--dit_path assets/checkpoints/rCM_Wan2.1_T2V_1.3B_480p.pt \
--num_samples 5 --num_steps 4 --sigma_max 80 \
--prompt "A cinematic shot of a snowy mountain at sunrise"4. Experimental Setup (实验设置)
4.1 训练配置
Table 4: 关键训练超参数
| 参数 | Cosmos-Predict2 T2I (0.6B/2B/14B) | Wan2.1 T2V (1.3B/14B) |
|---|---|---|
| EMA Length | 0.05 | 0.05 |
| Batch Size | 1024 / 512 / 256 | 256 / 64 |
| Context Parallel Size | 1 | 1 / 10 |
| Learning Rate (student) | 1e-6 | 2e-6 / 1e-6 |
| Learning Rate (fake score) | 2e-7 | 4e-7 / 1e-7 |
| CFG Scale | 4.5 | 5.0 |
| Student Update Frequency | 5 | 5 / 10 |
| Max Simulation Steps | 4 | 4 |
| Tangent Warmup | 0 | 1000 / 200 |
| Total Iterations | 80k / 50k / 25k | 10k / 1600 |
| 80 | - | |
| Optimizer | AdamW () | AdamW () |
| Weight Decay | 0.01 | 0.01 |
| Gradient Clipping | 禁用 | 禁用 |
采样时间步 (4-step):
4.2 其他实验设置
论文未详细说明。
5. Experimental Results (实验结果)
5.1 Text-to-Image (GenEval)
Figure 5 解读: 少步 T2I 生成的定性对比。rCM 在 Cosmos-Predict2 14B 上能渲染精细文字细节如 “Casio G-Shock”、“11:44 AM”、“Thursday, March 22nd”, 这是其他蒸馏方法难以做到的。相比 SDXL-DMD2、SDXL-Lightning 等方法, rCM 的细节保真度显著更高。
Table 1: GenEval 主要结果
| Model | Params | NFE | Overall | Single Obj | Two Obj | Counting | Colors | Position | Color Attr |
|---|---|---|---|---|---|---|---|---|---|
| Pretrained | |||||||||
| Cosmos-Predict2 | 14B | 35x2 | 0.84 | 1.00 | 0.98 | 0.79 | 0.90 | 0.64 | 0.72 |
| FLUX.1-schnell | 12B | 5x1 | 0.66 | 0.98 | 0.81 | 0.74 | 0.79 | 0.22 | 0.45 |
| Distilled | |||||||||
| Cosmos-Predict2 + DMD2 | 2B | 4 | 0.80 | 0.99 | 0.98 | 0.70 | 0.87 | 0.57 | 0.72 |
| Cosmos-Predict2 + rCM | 14B | 4 | 0.83 | 1.00 | 0.98 | 0.80 | 0.86 | 0.59 | 0.73 |
| Cosmos-Predict2 + rCM | 14B | 1 | 0.82 | 1.00 | 0.98 | 0.84 | 0.89 | 0.49 | 0.72 |
5.2 Text-to-Video (VBench 480p)
Table 2: VBench for Wan (480p)
| Model | Params | NFE | Throughput (FPS) | Total Score | Quality | Semantic |
|---|---|---|---|---|---|---|
| Wan2.1 T2V (teacher) | 1.3B | 50x2 | 0.72 | 83.02 | 83.95 | 79.26 |
| Wan2.1 T2V (teacher) | 14B | 50x2 | 0.18 | 83.58 | 84.26 | 80.92 |
| Wan2.1 + DMD2 | 1.3B | 4 | 14.6 | 84.56 | 85.58 | 80.50 |
| Wan2.1 + rCM | 1.3B | 4 | 14.6 | 84.43 | 85.38 | 80.63 |
| Wan2.1 + rCM | 14B | 4 | 4.5 | 84.92 | 85.43 | 82.88 |
| Wan2.1 + rCM | 14B | 2 | 8.3 | 85.05 | 85.57 | 82.95 |
| Wan2.1 + rCM | 14B | 1 | 14.4 | 83.60 | 83.57 | 80.81 |
关键发现:
- rCM 14B 2-step (85.05) 超越 480p teacher baseline (83.58), 加速 ~46x
- rCM 质量匹配甚至超越 DMD2, 同时保持显著更高的多样性
- 4-step 为质量与速度的最佳平衡点
Figure 1 解读: 在 Wan2.1 1.3B 上的 4-step 视频生成对比。sCM 有质量问题 (模糊纹理、不稳定几何); DMD2 和 SiD 出现 mode collapse (5 个随机种子生成的视频物体位置/朝向高度相似); rCM 既修复了 sCM 的质量问题, 又保持了与 teacher 相近的高多样性。
5.3 不同步数的效果
Figure 6 解读: 1-step、2-step、4-step 生成对比。T2I (Cosmos-Predict2 2B): 1-step 即可生成合理图像, 简单 prompt 几乎无差; 4-step 能渲染精细文字。T2V (Wan2.1 1.3B): 1-step 模糊; 2-step 已接近 teacher 质量; 4-step 进一步提升细节和复杂背景文字渲染。
5.4 Cosmos-Predict2 T2V (720p)
Table 3:
| Model | Params | Resolution | NFE | Throughput | T2V Score | I2V Score |
|---|---|---|---|---|---|---|
| Cosmos-Predict2 T2V | 2B | 1280x704x93 | 35x2 | 0.32 | 83.03 | 88.6 |
| Cosmos-Predict2 + rCM | 2B | 1280x704x93 | 4 | 4.6 | 84.40 | 88.2 |
rCM 在 720p 上同样超越 teacher, 吞吐量提升 ~14x.
5.5 Ablation: 的影响
Figure 7 解读: 不同 下 Wan2.1 1.3B 4-step rCM 的视频生成。 (强 DMD): VBench 84.32, 多样性低; : 84.57; : 84.43, sweet spot, 质量好且多样性高; : 82.68, 质量下降。每行 5 个随机种子, 可见 越大多样性越低。
5.6 sCM 的质量问题
Figure 3 解读: 纯 sCM 蒸馏的质量问题展示。上排 (T2I): Cosmos-Predict2 2B/14B, Easy Case 下 sCM 接近 teacher, 但 Hard Case (手表文字渲染) 严重退化, 且增大模型规模无法解决。下排 (T2V): Wan2.1 1.3B + sCM, 视频帧 40-50 出现模糊纹理和不稳定几何。这些问题源于 sCM 的误差累积和 forward divergence 的 mode-covering 特性。
总结
rCM 的核心 insight 简洁而有效: sCM 的 forward divergence 和 DMD 的 reverse divergence 天然互补. Forward 保多样性, reverse 修质量. 仅需一个超参数 即可跨模型、跨任务泛化, 无需 GAN tuning 或复杂超参搜索. 工程上, FlashAttention-2 JVP kernel + FSDP 重构 + CP 支持使得 JVP-based 蒸馏首次可扩展到 10B+ 模型和高维视频数据, 实现 14 步 15x50x 加速.