Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation

Authors: Hongzhou Zhu, Min Zhao, Guande He, Hang Su, Chongxuan Li, Jun Zhu Affiliations: 清华大学, ShengShu, UT Austin, 人大 arXiv: 2602.02214 GitHub: thu-ml/Causal-Forcing Year: 2026

1. Motivation (研究动机)

核心问题:实时交互式视频生成

当前实时交互式视频生成的主流方案是将预训练的双向(bidirectional)扩散模型蒸馏为少步自回归(AR)模型。这个过程面临一个根本性挑战——架构鸿沟(architectural gap)

  • 双向模型: 使用 full attention,每帧都能看到所有其他帧(过去+未来)
  • AR模型: 使用 causal attention,每帧只能看到过去的帧

现有方法(如 Self Forcing)采用两阶段蒸馏:

  1. ODE 蒸馏初始化: 从双向教师蒸馏到 AR 学生
  2. DMD 阶段: 进一步提升质量

核心发现:Self Forcing 的 ODE 蒸馏在理论上存在根本缺陷

Figure 3 解读: 这是本文最核心的图。展示了三种 ODE 蒸馏设置下的 injectivity 性质:

  • (a) 双向教师 → 双向学生: 标准 DMD 设置。Injectivity 在视频级别成立——不同的 noisy video 映射到不同的 clean video。学生能准确学到 flow map。
  • (b) AR 教师 → AR 学生 (Ours): Injectivity 在帧级别成立——每个 noisy frame 映射到唯一的 clean frame。这正是 AR 学生所需要的。
  • (c) 双向教师 → AR 学生 (Self Forcing): Injectivity 被违反!同一个 noisy frame 可能映射到多个不同的 clean frame(因为双向模型的 flow map 依赖其他帧的噪声),导致学生学到的是条件均值 E[x₀|xᵢₜ],产生模糊输出。

2. Idea (核心思想)

2.1 核心思路:让蒸馏数据与学生架构一致

关键思想: 使用 Stage 1 得到的 AR 扩散模型(而非双向模型)作为 ODE 蒸馏的教师。

直观上,方法的目标不是继续让双向 teacher 去“迁就” causal student,而是直接让 teacher 先具备 causal 结构,再做 ODE / DMD 蒸馏。这样可以把架构鸿沟前移到更合适的阶段处理。

2.2 为什么这样能解决问题?

定义 Frame-level injectivity (Definition 3.1): 对于映射 φ^{AR}: (xᵢₜ, t) → xᵢ₀,如果对任意两个 noisy video:

即 AR 教师的 PF-ODE flow map 天然满足帧级单射性——因为 AR 模型只依赖过去帧,第 i 帧的去噪结果完全由 (xᵢₜ, x₀^{<i}) 决定,不受未来帧影响。

2.3 训练目标的直观含义

其中 (xᵢₜ, xᵢ₀) 来自 AR 教师的 PF-ODE 轨迹,x_{gt}^{<i} 是真实干净前缀帧。这个目标对应的是“把 noisy frame 在正确的 causal 条件下还原为 clean frame”。

2.4 三阶段流水线概览

Causal Forcing 采用三阶段方案:

Stage 1: AR Diffusion Training (Teacher Forcing)
    → 获得 AR 扩散教师模型

Stage 2: Causal ODE Distillation
    → 用 AR 教师生成配对数据,蒸馏 AR 学生

Stage 3: Asymmetric DMD
    → 进一步提升为少步生成模型

3. Method (方法)

3.1 Stage 1: 自回归扩散训练——为什么选 Teacher Forcing 而非 Diffusion Forcing?

Figure 4 解读: 对比 Diffusion Forcing (DF) 和 Teacher Forcing (TF) 训练 AR 扩散模型的效果。DF 训练时前缀帧是含噪的(x_t^{<i}),推理时前缀是干净的(x_0^{<i}),存在训练-推理分布不匹配,导致视频崩溃(上排)。TF 训练时前缀帧就是干净帧,与推理一致,生成质量更高(下排)。

理论支撑 (Proposition 3.4):

即 Diffusion Forcing 学到的条件分布与真实条件分布之间有不可消除的 KL 散度。

3.2 Stage 2: Causal ODE 蒸馏——核心创新

关键思想: 使用 Stage 1 得到的 AR 扩散模型(而非双向模型)作为 ODE 蒸馏的教师。

为什么这解决了问题?

定义 Frame-level injectivity (Definition 3.1): 对于映射 φ^{AR}: (xᵢₜ, t) → xᵢ₀,如果对任意两个 noisy video:

即 AR 教师的 PF-ODE flow map 天然满足帧级单射性——因为 AR 模型只依赖过去帧,第 i 帧的去噪结果完全由 (xᵢₜ, x₀^{<i}) 决定,不受未来帧影响。

训练目标:

其中 (xᵢₜ, xᵢ₀) 来自 AR 教师的 PF-ODE 轨迹,x_{gt}^{<i} 是真实干净前缀帧。

3.3 Stage 3: Asymmetric DMD

使用 causal ODE 初始化后的模型进行标准的非对称 DMD 训练,使用 VidProM 数据集,训练 750 步至收敛。

Figure 5 解读: Self Forcing 的 ODE 初始化 + DMD 仍然产生较弱的动态效果和伪影(上排),而 Causal Forcing 的 ODE 初始化 + DMD 产生了更强的动态效果和更高的视觉保真度(下排)。这证明 causal ODE 蒸馏确实为 DMD 提供了正确的初始化。

3.4 扩展:Causal Consistency Distillation

本文还将思想扩展到 Consistency Distillation (CD),用 AR 教师替代双向教师:

3.5 Python 伪代码

import torch
import torch.nn.functional as F
 
# ============================================================
# Stage 1: AR Diffusion Training via Teacher Forcing
# ============================================================
def train_ar_diffusion_teacher_forcing(model, dataloader, num_steps=2000):
    """
    训练AR扩散模型,使用Teacher Forcing策略
    - 训练时前缀帧为干净帧(与推理一致)
    - 使用causal attention mask
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
 
    for step in range(num_steps):
        video = next(dataloader)  # [B, N, C, H, W], N=81 frames
 
        # 随机选择要训练的帧索引
        i = torch.randint(1, N, (B,))  # 第i帧
 
        # 前缀帧: 干净的 (Teacher Forcing的关键)
        prefix_clean = video[:, :i]  # x_0^{<i}, 干净帧作为条件
        target_clean = video[:, i]    # x_0^i, 目标帧
 
        # 对目标帧加噪
        t = torch.rand(B) * T  # 随机时间步
        noise = torch.randn_like(target_clean)
        x_t_i = alpha_t * target_clean + sigma_t * noise  # noisy target frame
 
        # 模型预测 (causal attention: 只看过去帧)
        # 输入: [prefix_clean | x_t_i], 带causal mask
        v_pred = model(x_t_i, prefix_clean, t)  # velocity prediction
 
        # Flow matching loss
        v_target = target_clean - noise  # 或其他参数化
        loss = F.mse_loss(v_pred, v_target)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    return model  # AR diffusion teacher
 
 
# ============================================================
# Stage 2: Causal ODE Distillation
# ============================================================
def generate_causal_ode_pairs(ar_teacher, dataloader, timesteps_S):
    """
    从AR教师模型采样PF-ODE轨迹,生成配对训练数据
    关键: 使用AR教师(而非双向教师)保证frame-level injectivity
    """
    paired_data = []
 
    for video_clean in dataloader:
        for i in range(1, N):
            prefix_clean = video_clean[:, :i]  # x_{gt}^{<i}
 
            # 从纯噪声开始,沿AR教师的PF-ODE采样
            x_T_i = torch.randn(B, C, H, W)  # 初始噪声
 
            # 存储轨迹上多个时间步的中间态
            trajectory = {}
            x_t = x_T_i
            for t in reversed(sorted(timesteps_S | {0})):
                trajectory[t] = x_t.clone()
                if t > 0:
                    # ODE step: dx = v_teacher(x_t, prefix_clean, t) dt
                    x_t = ode_step(ar_teacher, x_t, prefix_clean, t)
 
            x_0_i = trajectory[0]  # 教师生成的clean frame
 
            # 配对: {(x_t^i, t) -> x_0^i} for all t in S
            for t in timesteps_S:
                paired_data.append((trajectory[t], prefix_clean, t, x_0_i))
 
    return paired_data
 
 
def train_causal_ode_student(student, paired_data, num_steps=1000):
    """
    Causal ODE蒸馏: 学生从noisy frame直接回归clean frame
    因为AR教师满足frame-level injectivity,学生能准确学到flow map
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
 
    for step in range(num_steps):
        x_t_i, prefix_clean, t, x_0_i = sample_batch(paired_data)
 
        # 学生直接预测clean frame
        x_0_pred = student(x_t_i, prefix_clean, t)
 
        # MSE回归目标 (Eq. 8)
        loss = F.mse_loss(x_0_pred, x_0_i)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    return student
 
 
# ============================================================
# Stage 3: Asymmetric DMD
# ============================================================
def train_asymmetric_dmd(student, bi_teacher, real_scorer, fake_scorer,
                         dataloader, num_steps=750):
    """
    非对称DMD: 使用双向基座模型作为score teacher
    student已经由causal ODE蒸馏初始化(关键!)
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-6)
 
    for step in range(num_steps):
        prompt = next(dataloader)
 
        # 学生生成fake video (few-step AR generation)
        fake_video = student.generate(prompt, num_steps=1)  # 1-step generation
 
        # 对fake video加噪
        t = torch.rand(B) * T
        noise = torch.randn_like(fake_video)
        x_t = alpha_t * fake_video + sigma_t * noise
 
        # Score差异驱动更新 (Eq. 2)
        s_real = real_scorer(x_t, t)  # 冻结的真实数据score
        s_fake = fake_scorer(x_t, t)  # 在线更新的fake数据score
 
        # DMD gradient
        grad = (s_real - s_fake) * (d_fake_video / d_theta)
        loss = (grad * fake_video).sum()
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    return student
 
 
# ============================================================
# Inference: Real-time Streaming Generation
# ============================================================
def realtime_inference(model, prompt, num_frames=81, chunk_size=3):
    """
    实时流式推理: 每次生成一个chunk(3帧)
    单张H100/4090即可实时运行
    """
    frames = []
 
    for chunk_start in range(0, num_frames, chunk_size):
        prefix = torch.stack(frames) if frames else None
 
        # 从噪声开始, 1步生成当前chunk
        noise = torch.randn(1, chunk_size, C, H, W)
        chunk_frames = model(noise, prefix, prompt, t=1.0)  # 单步去噪
 
        frames.extend([chunk_frames[:, j] for j in range(chunk_size)])
 
        # 可以立即显示给用户 (streaming)
        yield chunk_frames

4. Experimental Setup (实验设置)

4.1 实现细节

  • 基座模型: Wan2.1-T2V-1.3B,生成 81 帧 832x480 视频
  • AR 训练: 在 3K 合成数据集 D_Bi 上训练 2K 步
  • Causal ODE: 采样 3K causal ODE 轨迹 D_Causal,训练 1K 步
  • DMD: 使用 VidProM 数据集,训练 750 步
  • 推理: chunk-wise(每 chunk 3 帧),单步生成,单卡 H100 实时
  • 吞吐量: 17.0 FPS,延迟 0.69 秒

4.2 代码映射 (GitHub Repo)

仓库地址: https://github.com/thu-ml/Causal-Forcing

论文概念代码位置说明
整体训练入口train.py支持三阶段训练的统一入口
AR Diffusion (Stage 1)trainer/ + configs/Teacher Forcing 训练 AR 扩散模型
Causal ODE 数据生成get_causal_ode_data_framewise.py, get_causal_ode_data_chunkwise.py从 AR 教师采样 PF-ODE 轨迹生成配对数据
ODE 蒸馏 + DMDtrainer/Causal ODE 初始化 + 非对称 DMD 训练
模型架构model/基于 Wan2.1 的 DiT 架构,支持 causal attention
推理流水线inference.py, pipeline/chunk-wise 流式生成
Causal CD 扩展configs/ 中 CD 相关配置Causal Consistency Distillation
基座模型集成wan/Wan2.1-T2V-1.3B 的适配层
长视频生成long_video/分钟级视频生成扩展

5. Experimental Results (实验结果)

5.1 定量比较

模型吞吐量↑延迟↓VBench Total↑Dynamic Degree↑VisionReward↑Instruct Follow↑用户评分↓
Wan2.1-1.3B0.7810383.37615.275422.29
LTX-1.9B8.9813.579.8346-6.218-386.40
NOVA0.884.180.3146-7.381-168.41
Pyramid Flow6.702.580.75164.055-26.11
CausVid17.00.6981.33625.741124.27
Self Forcing17.00.6983.74575.820482.87
Causal Forcing (Ours)17.00.6984.04686.326561.64

核心结论:

  • 相同训练预算和推理延迟下,全面超越 Self Forcing
  • Dynamic Degree 提升 19.3%,VisionReward 提升 8.7%,Instruction Following 提升 16.7%
  • 比双向模型 Wan2.1 吞吐量高 2079%,质量相当甚至超越

5.2 消融研究

Figure 2 解读: 验证 DMD 无法弥合架构鸿沟。即使用标准 DMD 初始化 AR 学生(消除了 sampling-step gap,只保留 architectural gap),性能仍显著低于标准 DMD 蒸馏的双向学生。这证明架构鸿沟必须在 ODE 蒸馏阶段解决,而非依赖后续的 DMD。

消融结论 (Table 2):

  • AR 训练策略: Teacher Forcing 全面优于 Diffusion Forcing(VisionReward 提升 111.2%)
  • ODE 初始化: Causal ODE 初始化大幅优于 Self Forcing 的 ODE 初始化(chunk-wise: VisionReward +90.0%, Dynamic Degree +183.3%)
  • Consistency Distillation: Causal CD 优于 Asymmetric CD

5.3 定性比较

Figure 6 解读: 与 Wan2.1、CausVid、Self Forcing 的定性对比。Causal Forcing 在动态程度和视觉质量上显著优于现有蒸馏方法,同时匹配甚至超越双向扩散模型 Wan2.1。

5.4 Figure 解读汇总

Figure 1 解读: 展示现有方法的局限性。从同一个双向基座模型蒸馏时,SOTA 的 AR 蒸馏方法 Self Forcing 仍然显著落后于标准 DMD(蒸馏出双向学生)。

Figure 7 解读: 这是 injectivity 示意图的第一张补充子图,进一步可视化帧级单射性。(a) 双向 flow map 在帧级不满足单射性——同一 noisy frame 在不同 video context 下映射到不同 clean frame。

Figure 8 解读: 这是 injectivity 示意图的第二张补充子图,进一步可视化帧级单射性。(c) AR flow map 天然满足帧级单射性。

Figure 9 解读: Causal ODE with Bidirectional Init。即使用双向模型初始化但采用 causal ODE 数据也能获得不错效果,验证了关键在于 ODE 数据来自 AR 教师。