rCM: Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency

Authors: Kaiwen Zheng, Yuji Wang, Qianli Ma, Huayu Chen, Jintao Zhang, Yogesh Balaji, Jianfei Chen, Ming-Yu Liu, Jun Zhu, Qinsheng Zhang Affiliations: Tsinghua University, NVIDIA arXiv: 2510.08431 Venue: ICLR 2026

1. Motivation (研究动机)

1.1 问题背景

连续时间一致性模型 (sCM, MeanFlow) 在理论上优雅, 但在大规模应用中面临三大挑战:

  1. 质量瓶颈: sCM 的 forward divergence (mode-covering) 特性导致细节生成退化 (如文字渲染失败、视频时序不一致)
  2. 误差累积: JVP 自反馈信号在 BF16 精度下数值脆弱, 随时间 增大误差被放大; 当 时, teacher 监督消失, 学习动态被 JVP 主导
  3. 工程限制: PyTorch 原生 torch.func.jvp 不兼容 FlashAttention-2、FSDP、Context Parallelism, 无法训练 10B+ 模型

1.2 现有方法的局限

论文未详细说明

2. Idea (核心思想)

2.1 核心贡献

贡献描述
rCM 框架联合 forward divergence (sCM) + reverse divergence (DMD) 的蒸馏方法, 即可平衡质量与多样性
FlashAttention-2 JVP Kernel基于 Triton 实现的 JVP 核, 支持 self/cross attention, 兼容 FSDP/CP
稳定时间导数计算Semi-continuous time (有限差分) 和 High-precision time (FP32 时间嵌入) 两种策略
大规模验证首次将连续时间一致性蒸馏扩展到 Cosmos-Predict2 (14B) T2I 和 Wan2.1 (14B) T2V, 14 步生成, 加速 15x50x

2.2 rCM: 联合目标

核心洞察: sCM (forward divergence) 提供 mode-covering (高多样性但低质量), DMD (reverse divergence) 提供 mode-seeking (高质量但低多样性). rCM 将 score distillation 作为 long-skip regularizer, 直接对大时间步提供 teacher 监督, 弥补 sCM 在大 时 JVP 主导导致的误差累积.

Figure 2 解读: 蒸馏方法的高层对比表。Consistency Model 仅有 forward divergence (高多样性、低质量); Score Distillation 仅有 reverse divergence (高质量、mode collapse); GAN 两者兼有但难调; rCM 同时具有两种 divergence, 既易调又高质高多样性

Figure 4 解读: rCM 的核心架构图。左侧: sCM 的 forward consistency 目标, 从 出发加噪, 通过 JVP 计算 tangent, 执行瞬时自洽性约束。误差从小 向大 传播累积。右侧: DMD 的 reverse divergence, 从 出发用 student 多步 rollout 生成样本 , 然后用 fake score network 与 teacher score 的差值提供梯度。两条路径互补: forward 保多样性, reverse 修质量。

3. Method (方法)

3.1 背景: sCM 与 Score Distillation

Diffusion Model 采样 PF-ODE:

一致性函数 , 满足边界条件 , 参数化为:

sCM 使用 TrigFlow 调度 , preconditioning .

sCM Loss (forward divergence, mode-covering):

其中 , , .

Score Distillation (reverse divergence, mode-seeking) - DMD Loss:

3.2 适配任意 Noise Schedule

Teacher 通常用 rectified flow (), 而 sCM 用 TrigFlow. 通过信噪比匹配构造 wrapped teacher:

其中 是从 TrigFlow 时间到原始时间的映射, 通过匹配 求得.

3.3 Rollout 策略

Student 生成样本用于 DMD loss 和 fake score 训练. 交替执行反向去噪和前向加噪:

  • 随机选择步数
  • 仅对最后一步反向传播 DMD loss
  • 随机采样 , 设 确保时间步单调递减

3.4 稳定时间导数计算

JVP 的时间导数 中, 因三角函数时间嵌入的振荡性而不稳定.

策略 1: Semi-Continuous Time (适用于 2B 以下 T2I):

策略 2: High-Precision Time (适用于 10B+ 和视频):

  • 使用 torch.amp.autocast 对所有时间嵌入层强制 FP32 精度
  • 保持原生连续时间导数的完整 JVP 计算

3.5 工程基础设施

3.5.1 FlashAttention-2 JVP Kernel

标准 attention: , ,

JVP 需要计算 tangent :

关键: O 和 tO 可以在同一个 streaming loop 中计算, 类似 FlashAttention-2 的 block-wise tiling, 用 Triton 实现.

3.5.2 FSDP 兼容性

定义 JVP 基类继承 torch.nn.Module, 每层实现 _forward_forward_jvp 两个接口:

class JVP(torch.nn.Module):
    def forward(self, *args, **kwargs):
        withT = kwargs.pop("withT", False)
        if withT:
            return self._forward_jvp(*args, **kwargs)
        else:
            return self._forward(*args, **kwargs)
  • withT=True 时接收 primal+tangent, 返回 primal+tangent (包装为 TensorWithT)
  • FSDP 粒度与 layer 边界对齐即可兼容
  • Attention block 用自定义 FlashAttention-2 JVP kernel 替换, 其余模块用 torch.func.jvp

3.5.3 Context Parallelism (CP)

Ulysses 策略: 输入张量 [B, H, L, C] 沿序列维度 L 切分到 P 个 GPU, 通过 all-to-all 重分布 QKV. JVP 自然扩展: tangent 的 QKV 同样参与分布, 使用 FlashAttention-2 JVP kernel 处理本地 attention.

3.6 训练伪代码与代码映射

3.6.1 Algorithm 1: rCM 完整训练算法

# rCM training loop
theta = copy_weights(theta_teacher)
theta_fake = copy_weights(theta_teacher)
 
for step in range(1, num_iterations + 1):
    if step <= tangent_warmup or step % student_update_freq == 0:
        x0 = sample_batch(dataset)
        eps = torch.randn_like(x0)
        t = sample_time(p_g)
        x_t = torch.cos(t) * x0 + torch.sin(t) * eps
 
        tangent = jvp_tangent(F_theta_minus, F_teacher, x_t, t)
        warmup_ratio = min(1.0, step / tangent_warmup)
        guidance = build_rcm_guidance(
            student_prev=F_theta_minus(x_t, t),
            teacher=F_teacher(x_t, t),
            tangent=tangent,
            x_t=x_t,
            t=t,
            warmup_ratio=warmup_ratio,
        )
 
        loss = consistency_loss(F_theta(x_t, t), F_theta_minus(x_t, t), guidance)
        if step > tangent_warmup:
            rollout_steps = sample_uniform(1, max_rollout_steps)
            x0_student = backward_rollout(student=F_theta, num_steps=rollout_steps)
            loss += dmd_loss(x0_student, fake_score=F_fake, teacher_score=F_teacher)
 
        optimize(loss, params=theta)
    else:
        rollout_steps = sample_uniform(1, max_rollout_steps)
        x0_student = backward_rollout(student=F_theta_minus, num_steps=rollout_steps)
        eps = torch.randn_like(x0_student)
        t = sample_time(p_d)
        x_t = torch.cos(t) * x0_student + torch.sin(t) * eps
        target = torch.cos(t) * eps - torch.sin(t) * x0_student
 
        loss_fake = flow_matching_loss(F_fake(x_t, t), target)
        optimize(loss_fake, params=theta_fake)

3.6.2 代码映射表

论文概念代码路径 (NVlabs/rcm)说明
rCM 训练主循环rcm/training/Generator + Critic 交替训练
sCM Loss (Eq. 4)rcm/training/tangent normalization, JVP 计算
DMD Loss (Eq. 6)rcm/training/student rollout + fake score gradient
FlashAttention-2 JVP Kernelrcm/utils/Triton 实现, 支持 self/cross attention
JVP 基类rcm/networks/JVP(torch.nn.Module), _forward/_forward_jvp
Network Restructuringrcm/networks/RMSNorm 等层的 JVP 适配
Wrapped Teacher (Eq. 3)rcm/models/TrigFlow-consistent teacher wrapper
T2V Inferencercm/inference/wan2pt1_t2v_rcm_infer.py支持 1-4 步采样
T2I Inferencercm/inference/Cosmos-Predict2 推理脚本
Training Configsrcm/configs/Hydra 配置文件

3.6.3 推理示例

PYTHONPATH=. python rcm/inference/wan2pt1_t2v_rcm_infer.py \
  --dit_path assets/checkpoints/rCM_Wan2.1_T2V_1.3B_480p.pt \
  --num_samples 5 --num_steps 4 --sigma_max 80 \
  --prompt "A cinematic shot of a snowy mountain at sunrise"

4. Experimental Setup (实验设置)

4.1 训练配置

Table 4: 关键训练超参数

参数Cosmos-Predict2 T2I (0.6B/2B/14B)Wan2.1 T2V (1.3B/14B)
EMA Length0.050.05
Batch Size1024 / 512 / 256256 / 64
Context Parallel Size11 / 10
Learning Rate (student)1e-62e-6 / 1e-6
Learning Rate (fake score)2e-74e-7 / 1e-7
CFG Scale4.55.0
Student Update Frequency 55 / 10
Max Simulation Steps 44
Tangent Warmup 01000 / 200
Total Iterations80k / 50k / 25k10k / 1600
80-
OptimizerAdamW ()AdamW ()
Weight Decay0.010.01
Gradient Clipping禁用禁用

采样时间步 (4-step):

4.2 其他实验设置

论文未详细说明

5. Experimental Results (实验结果)

5.1 Text-to-Image (GenEval)

Figure 5 解读: 少步 T2I 生成的定性对比。rCM 在 Cosmos-Predict2 14B 上能渲染精细文字细节如 “Casio G-Shock”、“11:44 AM”、“Thursday, March 22nd”, 这是其他蒸馏方法难以做到的。相比 SDXL-DMD2、SDXL-Lightning 等方法, rCM 的细节保真度显著更高。

Table 1: GenEval 主要结果

ModelParamsNFEOverallSingle ObjTwo ObjCountingColorsPositionColor Attr
Pretrained
Cosmos-Predict214B35x20.841.000.980.790.900.640.72
FLUX.1-schnell12B5x10.660.980.810.740.790.220.45
Distilled
Cosmos-Predict2 + DMD22B40.800.990.980.700.870.570.72
Cosmos-Predict2 + rCM14B40.831.000.980.800.860.590.73
Cosmos-Predict2 + rCM14B10.821.000.980.840.890.490.72

5.2 Text-to-Video (VBench 480p)

Table 2: VBench for Wan (480p)

ModelParamsNFEThroughput (FPS)Total ScoreQualitySemantic
Wan2.1 T2V (teacher)1.3B50x20.7283.0283.9579.26
Wan2.1 T2V (teacher)14B50x20.1883.5884.2680.92
Wan2.1 + DMD21.3B414.684.5685.5880.50
Wan2.1 + rCM1.3B414.684.4385.3880.63
Wan2.1 + rCM14B44.584.9285.4382.88
Wan2.1 + rCM14B28.385.0585.5782.95
Wan2.1 + rCM14B114.483.6083.5780.81

关键发现:

  • rCM 14B 2-step (85.05) 超越 480p teacher baseline (83.58), 加速 ~46x
  • rCM 质量匹配甚至超越 DMD2, 同时保持显著更高的多样性
  • 4-step 为质量与速度的最佳平衡点

Figure 1 解读: 在 Wan2.1 1.3B 上的 4-step 视频生成对比。sCM 有质量问题 (模糊纹理、不稳定几何); DMD2SiD 出现 mode collapse (5 个随机种子生成的视频物体位置/朝向高度相似); rCM 既修复了 sCM 的质量问题, 又保持了与 teacher 相近的高多样性。

5.3 不同步数的效果

Figure 6 解读: 1-step、2-step、4-step 生成对比。T2I (Cosmos-Predict2 2B): 1-step 即可生成合理图像, 简单 prompt 几乎无差; 4-step 能渲染精细文字。T2V (Wan2.1 1.3B): 1-step 模糊; 2-step 已接近 teacher 质量; 4-step 进一步提升细节和复杂背景文字渲染。

5.4 Cosmos-Predict2 T2V (720p)

Table 3:

ModelParamsResolutionNFEThroughputT2V ScoreI2V Score
Cosmos-Predict2 T2V2B1280x704x9335x20.3283.0388.6
Cosmos-Predict2 + rCM2B1280x704x9344.684.4088.2

rCM 在 720p 上同样超越 teacher, 吞吐量提升 ~14x.

5.5 Ablation: 的影响

Figure 7 解读: 不同 下 Wan2.1 1.3B 4-step rCM 的视频生成。 (强 DMD): VBench 84.32, 多样性低; : 84.57; : 84.43, sweet spot, 质量好且多样性高; : 82.68, 质量下降。每行 5 个随机种子, 可见 越大多样性越低。

5.6 sCM 的质量问题

Figure 3 解读: 纯 sCM 蒸馏的质量问题展示。上排 (T2I): Cosmos-Predict2 2B/14B, Easy Case 下 sCM 接近 teacher, 但 Hard Case (手表文字渲染) 严重退化, 且增大模型规模无法解决。下排 (T2V): Wan2.1 1.3B + sCM, 视频帧 40-50 出现模糊纹理和不稳定几何。这些问题源于 sCM 的误差累积和 forward divergence 的 mode-covering 特性。


总结

rCM 的核心 insight 简洁而有效: sCM 的 forward divergence 和 DMD 的 reverse divergence 天然互补. Forward 保多样性, reverse 修质量. 仅需一个超参数 即可跨模型、跨任务泛化, 无需 GAN tuning 或复杂超参搜索. 工程上, FlashAttention-2 JVP kernel + FSDP 重构 + CP 支持使得 JVP-based 蒸馏首次可扩展到 10B+ 模型和高维视频数据, 实现 14 步 15x50x 加速.