Causal Forcing++: Scalable Few-Step Autoregressive Diffusion Distillation for Real-Time Interactive Video Generation

Paper: arXiv:2605.15141
Code: thu-ml/Causal-Forcing; shengshu-ai/minWM
Code reference: main @ b2c9ebb3 (2026-05-15) for thu-ml/Causal-Forcing; main @ b28f0486 (2026-05-12) for shengshu-ai/minWM README/project skeleton

1. Motivation

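Real-time interactive video generation has to satisfy three requirements at once: low first-frame latency, streaming autoregressive (AR) rollout, and responses fine-grained enough to follow changes in user actions or conditions. Causal Forcing - Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation already showed that distilling a bidirectional video diffusion model into a few-step AR student yields high quality in a chunk-wise 4-step setting. However, chunk-wise generation still responds at a coarse granularity, and the interaction latency of 4-step sampling is not negligible.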

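This paper pushes toward a more aggressive setting: frame-wise autoregression with 1 or 2 sampling steps. The difficulty is not merely "fewer steps"; two sources of error compound. Smaller chunks increase the number of AR calls, and fewer steps per call amplify the approximation error of each call. During self-rollout the two combine into exposure bias, which shows up as motion collapse, scene breakdown, or drift in later frames.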

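The core bottleneck is how to initialize the few-step AR student. Self Forcing-style ODE initialization comes from a bidirectional teacher, so its target is misaligned with the AR conditional distribution. Initializing directly from multi-step AR diffusion lacks few-step capability. Causal Forcing's causal ODE initialization has the correct target, but it requires generating and storing complete PF-ODE trajectories offline; at the 80K-video scale, Stage 2 costs about 11,600 A800 GPU-hours in total plus about 1,900 GiB of extra storage. Removing this bottleneck would let high-quality video generation extend to real-time interaction and action-conditioned world models.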

2. Idea

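The key insight of Causal Forcing++ is that causal consistency distillation (causal CD) and causal ODE distillation learn the same AR-conditional flow map; only the form of supervision changes, from "large-span regression on complete offline trajectories" to "one online teacher ODE step between adjacent timesteps on real data". Causal CD therefore keeps the correct target of causal ODE distillation while removing the scaling bottleneck of precomputing and storing PF-ODE trajectories.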

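Compared with Self Forcing, both teacher and student here model the AR conditional $p(x^i \mid x^{<i}, c)$ given the prefix $x^{<i}$, instead of treating the bidirectional teacher's conditional expectation as the frame-level AR target. Compared with Causal Forcing - Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation, Stage 2 no longer regresses onto the complete teacher trajectory from noise to the clean frame; instead it trains against an EMA consistency target that enforces local consistency between adjacent timesteps $t_n$ and $t_{n+1}$. Intuitively, the local target spans a shorter interval and is smoother, so it is easier to optimize, while the error with respect to the target flow map is controlled only by the numerical error of the ODE solver.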

3. Method

3.1 Overall framework

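Figure 1 interpretation: the overall pipeline is still three-stage AR distillation. Stage 1 is teacher-forcing AR diffusion training. Stage 2 is the key piece this paper replaces: Causal Forcing's causal ODE initialization is swapped for a causal CD initialization. Stage 3 continues with asymmetric DMD plus self-rollout refinement, which turns the few-step AR student into the final interactive model. The figure also contrasts the teacher targets and scaling costs of Self Forcing, Causal Forcing, and Causal Forcing++.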

3.2 Why earlier initializations fail

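Figure 2 interpretation: Self Forcing ODE and multi-step AR diffusion degrade quickly in the more aggressive frame-wise / 1-step setting. Causal Forcing's causal ODE has the correct target, but its cost grows with dataset size and trajectory length, and the trajectories must be regenerated whenever the configuration changes. The figure supports the paper's design constraint: the initialization must be AR, few-step, and scalable at the same time.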

3.3 Causal CD objective

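Causal ODE distillation supervises the student on intermediate states of the AR teacher's PF-ODE. For frame $i$ with ground-truth prefix $x^{<i}$ and condition $c$, let $x^i_{t_k}$ be states on a stored teacher trajectory of $\mathrm{d}x^i_t/\mathrm{d}t = v_\phi(x^i_t, t \mid x^{<i}, c)$ started from noise $x^i_1 \sim \mathcal{N}(0, I)$, and let $\bar{x}^i_0$ be its clean endpoint:

$$\mathcal{L}_{\text{ODE}}(\theta) = \mathbb{E}_{i,k}\Big[\big\|G_\theta(x^i_{t_k}, t_k \mid x^{<i}, c) - \bar{x}^i_0\big\|_2^2\Big]$$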

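Its optimum is the teacher's AR-conditional flow map $\Phi_\phi$, i.e. the endpoint obtained by integrating the teacher PF-ODE from $x^i_t$ at time $t$ down to time $0$:

$$G_{\theta^*}(x^i_t, t \mid x^{<i}, c) = \Phi_\phi(x^i_t;\, t \to 0 \mid x^{<i}, c)$$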

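Causal CD rewrites this objective as consistency between adjacent timesteps $t_n > t_{n+1}$ of an $N$-point discretization, using noised real data $x^i_{t_n} = (1 - t_n)\,x^i_0 + t_n\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$:

$$\mathcal{L}_{\text{CD}}(\theta) = \mathbb{E}_{i,n,x^i_0,\epsilon}\Big[\lambda(t_n)\, d\big(f_\theta(x^i_{t_n}, t_n \mid x^{<i}, c),\; f_{\theta^-}(\hat{x}^i_{t_{n+1}}, t_{n+1} \mid x^{<i}, c)\big)\Big]$$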

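Here $\hat{x}^i_{t_{n+1}}$ is obtained by one ODE step of the AR teacher from $x^i_{t_n}$, conditioned on the ground-truth prefix $x^{<i}$; $f_{\theta^-}$ is the stop-gradient EMA student; $\lambda(t_n)$ is a time weighting; and $d(\cdot,\cdot)$ is a distance under a predefined norm. With an Euler solver and classifier-free guidance (scale $w$, negative-prompt condition $\bar{c}$), the teacher step is $\hat{x}^i_{t_{n+1}} = x^i_{t_n} - (t_n - t_{n+1})\,\big[v_\phi(\cdot \mid x^{<i}, \bar{c}) + w\,(v_\phi(\cdot \mid x^{<i}, c) - v_\phi(\cdot \mid x^{<i}, \bar{c}))\big]$, evaluated at $(x^i_{t_n}, t_n)$. The flow-matching parameterization is:

$$f_\theta(x_t, t \mid x^{<i}, c) = x_t - t\, v_\theta(x_t, t \mid x^{<i}, c)$$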

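Under the standard consistency-distillation analysis (loss driven to zero, solver local error $O((\Delta t)^{p+1})$), the error between the optimal model and the target flow map satisfies:

$$\sup_{n,\,x}\big\|f_{\theta^*}(x, t_n \mid x^{<i}, c) - \Phi_\phi(x;\, t_n \to 0 \mid x^{<i}, c)\big\| = O\big((\Delta t)^p\big)$$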

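where $\Delta t = \max_n |t_n - t_{n+1}|$ is the largest gap between adjacent timesteps and the ODE solver has order $p$ ($p = 1$ for Euler). In other words, causal CD and causal ODE point to the same AR flow map; the difference is that causal ODE regresses on large-span offline trajectories, whereas causal CD approaches the flow map through online local consistency constraints.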
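A toy 1-D sketch can make the $O((\Delta t)^p)$ statement concrete. It is not from the repository: it assumes Gaussian "data" $x_0 \sim \mathcal{N}(0, S^2)$ with $S = 0.5$, for which the exact flow-matching velocity and flow map have closed forms, and all names (`teacher_velocity`, `flow_map_to_clean`, and so on) are hypothetical. It checks that the consistency residual of the exact flow map under one Euler teacher step shrinks like $(\Delta t)^2$, and that the accumulated error of a full Euler trajectory shrinks like $\Delta t$ (order $p = 1$).

```python
import math

S = 0.5  # toy assumption: data x0 ~ N(0, S^2), noise eps ~ N(0, 1)


def sigma(t):
    """Std of x_t = (1 - t) * x0 + t * eps under the toy Gaussian data."""
    return math.sqrt((1 - t) ** 2 * S ** 2 + t ** 2)


def teacher_velocity(x, t):
    """Exact flow-matching velocity E[eps - x0 | x_t = x] for the toy."""
    return x * (t - (1 - t) * S ** 2) / sigma(t) ** 2


def flow_map_to_clean(x, t):
    """Exact PF-ODE flow map from time t to time 0 (the target f*)."""
    return x * S / sigma(t)


def euler_step(x, t, t_next):
    """One teacher Euler step toward the clean end (t_next < t)."""
    return x - (t - t_next) * teacher_velocity(x, t)


def euler_to_clean(x, n_steps):
    """Integrate the PF-ODE from t=1 to t=0 with n_steps uniform Euler steps."""
    ts = [1 - k / n_steps for k in range(n_steps + 1)]
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = euler_step(x, t, t_next)
    return x


def consistency_residual(x, t, t_next):
    """|f*(x_t, t) - f*(x_hat_{t_next}, t_next)|; zero for an exact ODE step."""
    x_hat = euler_step(x, t, t_next)
    return abs(flow_map_to_clean(x, t) - flow_map_to_clean(x_hat, t_next))


x1 = 1.3  # a noise sample at t = 1
exact = flow_map_to_clean(x1, 1.0)
err_48 = abs(euler_to_clean(x1, 48) - exact)  # 48 steps, as in discrete_cd_N
err_96 = abs(euler_to_clean(x1, 96) - exact)
res_coarse = consistency_residual(0.8, 0.5, 0.48)
res_fine = consistency_residual(0.8, 0.5, 0.49)
print(err_48 / err_96, res_coarse / res_fine)
```

Halving the step roughly halves the end-to-end Euler error, while the per-step consistency residual drops about fourfold; this local-versus-global order relationship is what the bound above expresses.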

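Figure 3a–3b interpretation: the left panel shows causal CD scoring higher than or on par with causal ODE both at the end of Stage 2 and after Stage 3 training initialized from it. The right panel shows causal CD reducing Stage 2 time at the 80K-video scale from about 11,600 to about 2,900 A800 GPU-hours, and the extra storage for ODE paired data from about 1,900 GiB to 0.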

3.4 Why not initialize directly with causal DMD

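The paper also analyzes causal score distillation (causal DMD). The distribution-matching gradient of DMD can be written as:

$$\nabla_\theta \mathcal{L}_{\text{DMD}} = -\,\mathbb{E}_{t,z,\epsilon}\Big[w_t\,\big(s_{\text{real}}(x_t, t) - s_{\text{fake}}(x_t, t)\big)\,\frac{\partial G_\theta(z)}{\partial \theta}\Big],\qquad x_t = (1 - t)\,G_\theta(z) + t\,\epsilon$$

where $s_{\text{real}}$ is the teacher score and $s_{\text{fake}}$ is the score of a model fitted to the student's samples.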

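Its AR teacher-forcing version conditions both scores and the generator on the ground-truth prefix:

$$\nabla_\theta \mathcal{L}_{\text{causal-DMD}} = -\,\mathbb{E}_{i,t,z,\epsilon}\Big[w_t\,\big(s_{\text{real}}(x^i_t, t \mid x^{<i}, c) - s_{\text{fake}}(x^i_t, t \mid x^{<i}, c)\big)\,\frac{\partial G_\theta(z^i \mid x^{<i}, c)}{\partial \theta}\Big],\qquad x^i_t = (1 - t)\,G_\theta(z^i \mid x^{<i}, c) + t\,\epsilon$$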

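Figure 5a–5b interpretation: causal DMD can produce sharper early frames, but under history shift its mode-seeking distribution matching pushes probability mass toward poor-quality regions; as prefix errors accumulate during AR rollout, later frames expose this bias more readily. The local consistency target of causal CD stays closer to the teacher flow, so the subsequent self-rollout is more stable.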
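The mode-seeking behavior can be illustrated with a generic toy: a bimodal 1-D "real" distribution and a unit-variance Gaussian "fake" distribution whose mean is fitted by grid search. This sketch covers only the reverse-KL property that DMD inherits (DMD minimizes $\mathrm{KL}(p_{\text{fake}} \,\|\, p_{\text{real}})$); it does not model history shift or reproduce the paper's experiments, and all names are hypothetical.

```python
import math


def normal_pdf(x, mu, sd=1.0):
    return math.exp(-0.5 * ((x - mu) / sd) ** 2) / (sd * math.sqrt(2 * math.pi))


def real_pdf(x):
    """Toy bimodal 'real' distribution: 0.5 N(-3, 1) + 0.5 N(3, 1)."""
    return 0.5 * normal_pdf(x, -3.0) + 0.5 * normal_pdf(x, 3.0)


DX = 0.02
GRID = [-12.0 + DX * i for i in range(1201)]  # Riemann grid on [-12, 12]
MUS = [-5.0 + 0.1 * i for i in range(101)]    # candidate means of q = N(mu, 1)


def reverse_kl(mu):
    """KL(q || p): the fake-to-real direction that DMD minimizes."""
    return DX * sum(
        normal_pdf(x, mu) * math.log(normal_pdf(x, mu) / real_pdf(x)) for x in GRID
    )


def forward_kl(mu):
    """KL(p || q): the mass-covering direction, for comparison."""
    return DX * sum(
        real_pdf(x) * math.log(real_pdf(x) / normal_pdf(x, mu)) for x in GRID
    )


best_reverse = min(MUS, key=reverse_kl)
best_forward = min(MUS, key=forward_kl)
print(round(best_reverse, 1), round(best_forward, 1))
```

Reverse KL settles on one mode and ignores the other, whereas forward KL places the mean between the two modes.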

3.5 Action-conditioned world model extension

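Figure 4 interpretation: the authors apply Causal Forcing++ to a Genie3-style camera-pose-conditioned world model. They first build a camera-pose-annotated dataset with WorldPlay, then use PRoPE to fine-tune Wan2.1-1.3B into a bidirectional camera-pose-conditioned diffusion model, and finally distill it with Causal Forcing++ into an interactive AR world model. The public shengshu-ai/minWM repository at main@b28f0486 contains only a README-level pipeline; the training/inference implementation is still being organized.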

3.6 Pseudocode (based on the public code)

The core of the Stage 2 causal CD implementation corresponds to model/naive_consistency.py::NaiveConsistency.generator_loss and trainer/naive_cd.py::Trainer:

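import random

import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

# model, text_encoder, cached_text_encoder, negative_prompt, teacher, generator,
# generator_ema (frozen module copy), ema (EMA weight holder), scheduler, optimizer,
# guidance_scale, discrete_cd_N and max_grad_norm_generator come from the trainer
# setup and are treated as globals in this sketch.


class CausalCDTrainer:
    def step(self, batch):
        prompts = batch["prompts"]
        clean_latent = batch["clean_latent"].to(torch.bfloat16)
        cond = text_encoder(prompts)
        uncond = cached_text_encoder([negative_prompt] * len(prompts))

        optimizer.zero_grad(set_to_none=True)
        loss, logs = model.generator_loss(
            conditional_dict=cond,
            unconditional_dict=uncond,
            clean_latent=clean_latent,
            ema_model=ema,
        )
        loss.backward()
        clip_grad_norm_(model.generator.parameters(), max_grad_norm_generator)
        optimizer.step()
        ema.update(model.generator)  # theta^- tracks an EMA of the student weights
        return logs


def causal_cd_generator_loss(clean_latent, cond, uncond, ema_model):
    # Adjacent pair (t_n, t_{n+1}); scheduler.timesteps runs noisy -> clean.
    k = random.randrange(discrete_cd_N - 1)
    t = scheduler.timesteps[k]
    t_next = scheduler.timesteps[k + 1]

    # Noise real data directly: no precomputed PF-ODE trajectories are needed.
    noise = torch.randn_like(clean_latent)
    x_t = scheduler.add_noise(clean_latent, noise, timestep=t)

    # One online Euler step of the AR teacher with CFG; clean_x supplies the
    # ground-truth prefix (teacher forcing).
    with torch.no_grad():
        v_c, _ = teacher(x_t, cond, timestep=t, clean_x=clean_latent)
        v_u, _ = teacher(x_t, uncond, timestep=t, clean_x=clean_latent)
        v = v_u + guidance_scale * (v_c - v_u)
        x_t_next = x_t - ((t - t_next) / 1000.0) * v  # sigma = timestep / 1000

    # Student x0 prediction at t_n: the only term that receives gradients.
    _, pred_t = generator(x_t, cond, timestep=t, clean_x=clean_latent)

    # Stop-gradient EMA target at t_{n+1}.
    with torch.no_grad():
        ema_model.copy_to(generator_ema)
        _, target_next = generator_ema(x_t_next, cond, timestep=t_next, clean_x=clean_latent)

    loss = F.mse_loss(pred_t, target_next)
    return loss, {"generator_loss": loss.detach()}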

Stage 3 asymmetric DMD corresponds to model/dmd.py::DMD._compute_kl_grad and configs/causal_forcing_dmd_framewise_2step.yaml:

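import torch
import torch.nn.functional as F


# fake_score / real_score (globals here) return (flow_pred, x0_pred); [1] is x0.
def dmd_distribution_matching(noisy_latent, clean_estimate, timestep, cond, uncond):
    with torch.no_grad():  # the score models only supply a direction
        pred_fake_c = fake_score(noisy_latent, cond, timestep)[1]
        pred_fake_u = fake_score(noisy_latent, uncond, timestep)[1]
        # CFG as cond + s * (cond - uncond) == uncond + (1 + s) * (cond - uncond)
        pred_fake = pred_fake_c + fake_guidance_scale * (pred_fake_c - pred_fake_u)

        pred_real_c = real_score(noisy_latent, cond, timestep)[1]
        pred_real_u = real_score(noisy_latent, uncond, timestep)[1]
        pred_real = pred_real_c + real_guidance_scale * (pred_real_c - pred_real_u)

        # DMD direction in x0 space, normalized per sample over (frame, C, H, W).
        grad = pred_fake - pred_real
        normalizer = (clean_estimate - pred_real).abs().mean(dim=[1, 2, 3, 4], keepdim=True)
        grad = torch.nan_to_num(grad / normalizer)

    # Surrogate loss: its gradient w.r.t. clean_estimate equals grad / numel.
    return 0.5 * F.mse_loss(clean_estimate.double(), (clean_estimate.double() - grad.double()).detach())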
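Why the last line injects the DMD direction: because the target `(clean_estimate - grad)` is detached, the derivative of `0.5 * mse(x, x - grad)` with respect to `x` is `grad / numel`. The pure-Python check below uses made-up numbers and hypothetical helper names (it is an illustration, not repository code) and confirms this with central finite differences.

```python
def surrogate_loss(x, target):
    """0.5 * mean((x - target)^2), with target held fixed (the .detach())."""
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target)) / len(x)


def central_diff_grad(f, x, eps=1e-6):
    """Finite-difference gradient of a scalar function f at point x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grad


clean_estimate = [0.3, -1.2, 2.0]
dmd_grad = [0.5, 0.1, -0.4]  # stands in for the normalized pred_fake - pred_real
target = [x - g for x, g in zip(clean_estimate, dmd_grad)]  # frozen target
grad = central_diff_grad(lambda v: surrogate_loss(v, target), clean_estimate)
print(grad)  # approximately dmd_grad divided by len(clean_estimate)
```

The design lets a standard optimizer apply an externally computed gradient without a custom autograd function.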

The first-frame trick for frame-wise 1/2-step generation corresponds to pipeline/self_forcing_training.py::SelfForcingTrainingPipeline and the denoising_step_list_first_chunk field in the config:

def choose_block_schedule(block_index):
    if block_index == 0 and denoising_step_list_first_chunk is not None:
        # Causal Forcing++ 1/2-step: first latent frame still uses 4-step schedule.
        return denoising_step_list_first_chunk  # [1000, 750, 500, 250]
    return denoising_step_list  # e.g. 2-step [1000, 500] or 1-step [1000]
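To see what the first-frame trick costs, here is a small counting sketch. It assumes 21 latent frames per 81-frame video (Wan's VAE compresses time by 4×, so 1 + 80/4 = 21) and counts only generator denoising calls, ignoring any KV-cache update passes and VAE decoding; `denoising_passes` is a hypothetical helper.

```python
def denoising_passes(num_latent_frames, step_list, first_chunk_step_list=None):
    """Count generator denoising calls for frame-wise AR generation.

    Frame 0 uses first_chunk_step_list when given (the first-frame trick);
    every later frame uses step_list.
    """
    first = first_chunk_step_list if first_chunk_step_list is not None else step_list
    return len(first) + (num_latent_frames - 1) * len(step_list)


FIRST = [1000, 750, 500, 250]
for name, steps in [("1-step", [1000]), ("2-step", [1000, 500]),
                    ("4-step", [1000, 750, 500, 250])]:
    print(name, "first frame:", len(FIRST), "total:", denoising_passes(21, steps, FIRST))
```

Because the first frame costs four calls in every variant, this count is consistent with the identical 0.27 s first-frame latency reported for the 1/2/4-step models in the main comparison, while later frames get cheaper as the step list shrinks.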

3.7 Code-to-paper mapping

Code reference: main @ b2c9ebb3 (2026-05-15); the pseudocode and mapping are based on this commit. minWM was checked at main @ b28f0486 (2026-05-12), but only its README/pipeline skeleton is public.

| Paper concept | Source file | Key class/function |
|---|---|---|
| Stage 1 teacher-forcing AR diffusion | train.py, trainer/diffusion.py, configs/ar_diffusion_tf_framewise.yaml | DiffusionTrainer, config trainer: diffusion |
| Stage 2 causal CD init | train.py, trainer/naive_cd.py, model/naive_consistency.py, configs/causal_cd_framewise.yaml | ConsistencyDistillationTrainer, NaiveConsistency.generator_loss, discrete_cd_N: 48 |
| Online AR teacher one-step target | model/naive_consistency.py | teacher conditional/unconditional velocity, CFG, latent_t_next = latent_t - dt * v_pred |
| EMA consistency target | trainer/naive_cd.py, model/naive_consistency.py | EMA_FSDP, ema_model.copy_to(self.generator_ema), F.mse_loss(cm_pred_t, cm_pred_t_next) |
| Stage 3 asymmetric DMD | trainer/distillation.py, model/dmd.py, configs/causal_forcing_dmd_framewise_2step.yaml | ScoreDistillationTrainer, DMD._compute_kl_grad, trainer: score_distillation |
| Frame-wise 2-step + first-frame 4-step trick | configs/causal_forcing_dmd_framewise_2step.yaml, pipeline/self_forcing_training.py | denoising_step_list: [1000, 500], denoising_step_list_first_chunk: [1000, 750, 500, 250] |
| Inference entrypoint and released checkpoints | README.md, inference.py, configs/causal_forcing_dmd_framewise_1step.yaml, configs/causal_forcing_dmd_framewise_2step.yaml | CLI commands for 1-step/2-step Causal Forcing++ |
| Action-conditioned minWM extension | shengshu-ai/minWM/README.md | public repo currently lists data pipeline, bidirectional action-conditioned finetuning, AR finetuning, AR distillation, real-time inference; implementation files not released at checked commit |

4. Experimental Setup

Models and generation setting. The base model is Wan2.1-1.3B; videos are 81 frames long and are produced by frame-wise autoregressive generation. In Stage 3 DMD the generator/student is still based on Wan2.1-1.3B, but the two score models use Wan2.1-14B.

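Three-stage training. Stage 1 is teacher-forcing AR diffusion training. Stage 2 is the few-step initialization, which here uses causal CD with a squared-norm distance, 48 discretized timesteps, and an Euler solver. Stage 3 is asymmetric DMD with self-rollout. The paper reports 20K, 5K, and 1K training steps for the three stages, with a batch size of 64. In the public code, the Stage 2 config configs/causal_cd_framewise.yaml uses trainer: consistency_distillation, generator_ckpt: checkpoints/framewise/ar_diffusion.pt, lr: 2.0e-06, and batch_size: 1; the README recommends launching on 8 nodes × 8 processes, so a per-process batch of 1 corresponds to a global batch of 64. The same config file also contains total_batch_size: 8, which should be read as a template field of the public release rather than the final global batch size reported in the paper's tables.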
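The batch arithmetic above, and the resulting number of samples per stage, can be checked in a few lines. This sketch assumes no gradient accumulation and that the global batch of 64 applies to all three stages, as the paper's summary states; for Stage 3 the "samples" are prompts.

```python
NODES, PROCS_PER_NODE, PER_PROC_BATCH = 8, 8, 1  # README launch, Stage 2 batch_size
global_batch = NODES * PROCS_PER_NODE * PER_PROC_BATCH  # assumes no grad accumulation

steps = {"stage1": 20_000, "stage2": 5_000, "stage3": 1_000}  # paper-reported steps
samples_seen = {stage: n * global_batch for stage, n in steps.items()}
stage2_epochs = samples_seen["stage2"] / 80_000  # passes over the 80K-video dataset

print(global_batch)   # 64
print(samples_seen)   # {'stage1': 1280000, 'stage2': 320000, 'stage3': 64000}
print(stage2_epochs)  # 4.0
```

Under these assumptions, Stage 2 causal CD sees each of the 80K videos about four times.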

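Data. Stages 1 and 2 use an 80K dataset containing videos sampled from OpenVid; Stage 3 uses VidProM. For the action-conditioned world model extension, a camera-pose-annotated training dataset is first built with WorldPlay; pose conditioning is then injected via PRoPE before distillation.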

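Evaluation. The main benchmarks are VBench and VisionReward. For VBench, Total, Quality, and Semantic are reported, and Dynamic Degree is evaluated separately on the 100 prompts from Causal Forcing; VisionReward uses the same 100 prompts and reports the overall score and Instruction Following. All metrics are multiplied by 100. Efficiency is measured by first-frame latency and throughput; the paper states that both are measured on a single A800 and exclude VAE-related time.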

5. Experimental Results

5.1 Main comparison

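| Model | Throughput (FPS) ↑ | Latency (s) ↓ | VBench Total ↑ | Quality ↑ | Semantic ↑ | Dynamic ↑ | VisionReward ↑ | Instruct ↑ |
|---|---|---|---|---|---|---|---|---|
| CausVid | 10.4 | 0.60 | 81.33 | 83.98 | 70.72 | 62 | 5.741 | 12 |
| Self Forcing | 10.4 | 0.60 | 83.74 | 84.48 | 80.77 | 57 | 5.820 | 48 |
| Causal Forcing | 10.4 | 0.60 | 84.04 | 84.59 | 81.84 | 68 | 6.326 | 56 |
| Causal Forcing++ (1-step) | 20.7 | 0.27 | 83.35 | 84.50 | 78.75 | 66 | 5.412 | 38 |
| Causal Forcing++ (2-step) | 14.1 | 0.27 | 84.14 | 84.89 | 81.13 | 64 | 6.661 | 51 |
| Causal Forcing++ (4-step) | 8.69 | 0.27 | 84.10 | 84.94 | 80.75 | 71 | 6.798 | 47 |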

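The 2-step Causal Forcing++ reaches 84.14 VBench Total, above Causal Forcing's 84.04; its Quality is 84.89 vs. 84.59, and its VisionReward is 6.661 vs. 6.326. It trails Causal Forcing on Semantic (81.13 vs. 81.84), Dynamic Degree (64 vs. 68), and Instruction Following (51 vs. 56). First-frame latency drops from 0.60 s to 0.27 s, a reduction of about 55% (roughly 2.2× lower), and throughput rises from 10.4 FPS to 14.1 FPS. The 4-step Causal Forcing++ scores higher than the 2-step variant on Quality, Dynamic Degree, and VisionReward, but its throughput is lower (8.69 vs. 14.1 FPS).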
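The relative changes quoted above can be recomputed from the main comparison table; the values are copied from it, and the dictionary layout is only for illustration.

```python
results = {  # (throughput in FPS, first-frame latency in s) from the main table
    "Causal Forcing": (10.4, 0.60),
    "Causal Forcing++ (1-step)": (20.7, 0.27),
    "Causal Forcing++ (2-step)": (14.1, 0.27),
    "Causal Forcing++ (4-step)": (8.69, 0.27),
}
base_fps, base_latency = results["Causal Forcing"]
for name, (fps, latency) in results.items():
    latency_cut = 1 - latency / base_latency
    print(f"{name}: {fps / base_fps:.2f}x throughput, {latency_cut:.0%} lower latency")
```

The 2-step variant comes out at about 1.36× the baseline throughput and the 1-step variant at about 2×, while the frame-wise 4-step variant is slower than the Causal Forcing baseline (about 0.84×).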

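Figure 6 interpretation: the visual comparison shows CausVid and Self Forcing are weaker in dynamics or consistency, while Causal Forcing++ approaches or surpasses Causal Forcing in quality and motion with fewer frame-wise steps and is stronger on aesthetic dimensions such as color and brightness.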

5.2 Ablation on initialization

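| Setting | Initialization | Total ↑ | Quality ↑ | Semantic ↑ | Dynamic ↑ | VisionReward ↑ | Instruct ↑ | Stage 2 Time (A800 GPU-hours) ↓ | Extra Storage (GiB) ↓ |
|---|---|---|---|---|---|---|---|---|---|
| 1-step | Self Forcing | 78.87 | 79.85 | 74.95 | 0 | 1.992 | -12 | 5000 | 1500 |
| 1-step | AR diffusion | 80.54 | 80.97 | 78.84 | 0 | 1.101 | -14 | - | 0 |
| 1-step | Causal ODE | 83.06 | 83.88 | 79.77 | 46 | 5.464 | 40 | 11600 | 1900 |
| 1-step | Causal DMD | 82.34 | 83.50 | 77.71 | 62 | 4.868 | 20 | 2900 | 0 |
| 1-step | Causal CD | 83.35 | 84.50 | 78.75 | 66 | 5.412 | 38 | 2900 | 0 |
| 2-step | Causal ODE | 83.77 | 84.42 | 81.19 | 57 | 6.224 | 46 | 11600 | 1900 |
| 2-step | Causal CD | 84.14 | 84.89 | 81.13 | 64 | 6.661 | 51 | 2900 | 0 |
| 4-step | Causal ODE | 83.78 | 84.90 | 79.28 | 75 | 6.435 | 42 | 11600 | 1900 |
| 4-step | Causal CD | 84.10 | 84.94 | 80.75 | 71 | 6.798 | 47 | 2900 | 0 |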

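Causal CD beats causal ODE on VBench Total and Quality at every step count (1/2/4-step) and is mixed on the remaining metrics (for example, slightly lower VisionReward at 1 step, higher at 2 and 4 steps), while cutting Stage 2 time from 11,600 to 2,900 A800 GPU-hours (4×) and extra storage from 1,900 GiB to 0. Multi-step AR diffusion shows no measurable motion at 1 step (Dynamic Degree 0), which indicates that a dedicated few-step initialization is not optional. Causal DMD is as cheap as causal CD, but at 1 step its VisionReward is about 0.54 lower (4.868 vs. 5.412), and it is more susceptible to history shift during rollout.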
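The deltas behind these statements can be recomputed from the ablation table. The dictionary below copies the relevant columns (Total, Quality, VisionReward, Stage 2 time, extra storage) and is only an illustrative layout.

```python
# (Total, Quality, VisionReward, Stage 2 A800 GPU-hours, extra storage GiB)
ablation = {
    ("1-step", "Causal ODE"): (83.06, 83.88, 5.464, 11600, 1900),
    ("1-step", "Causal DMD"): (82.34, 83.50, 4.868, 2900, 0),
    ("1-step", "Causal CD"): (83.35, 84.50, 5.412, 2900, 0),
    ("2-step", "Causal ODE"): (83.77, 84.42, 6.224, 11600, 1900),
    ("2-step", "Causal CD"): (84.14, 84.89, 6.661, 2900, 0),
    ("4-step", "Causal ODE"): (83.78, 84.90, 6.435, 11600, 1900),
    ("4-step", "Causal CD"): (84.10, 84.94, 6.798, 2900, 0),
}

for setting in ("1-step", "2-step", "4-step"):
    ode = ablation[(setting, "Causal ODE")]
    cd = ablation[(setting, "Causal CD")]
    deltas = [round(c - o, 3) for c, o in zip(cd[:3], ode[:3])]
    print(setting, "CD - ODE (Total, Quality, VisionReward):", deltas)

speedup = ablation[("1-step", "Causal ODE")][3] / ablation[("1-step", "Causal CD")][3]
dmd_gap = ablation[("1-step", "Causal CD")][2] - ablation[("1-step", "Causal DMD")][2]
print(f"Stage 2 speedup: {speedup:.1f}x; 1-step VisionReward gap (CD - DMD): {dmd_gap:.3f}")
```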

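Figure 7 interpretation: at 1 step, AR diffusion is severely blurred with almost no motion, causal ODE shows scene collapse, and causal DMD smears local object structures together; at 2 and 4 steps, causal CD is also more stable. These visual results agree with the Dynamic, VisionReward, and Instruction Following columns of the table.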

5.3 Limitations / caveats

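The authors do not include a dedicated limitations section. Caveats worth noting: the efficiency metrics are measured on a single A800 and exclude VAE time; the 1/2-step variants still use the first-frame 4-step trick, so the first frame is not truly 1-step; the action-conditioned world model extension is mainly a demo / qualitative showcase, and the public minWM repository has not released full training/inference code at the checked commit; beyond the README and configs, the paper does not describe in detail how to reproduce the full 80K dataset, the filtered VidProM set, or the internal training logs.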