Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation
Authors: Hongzhou Zhu, Min Zhao, Guande He, Hang Su, Chongxuan Li, Jun Zhu Affiliations: 清华大学, ShengShu, UT Austin, 人大 arXiv: 2602.02214 GitHub: thu-ml/Causal-Forcing Year: 2026
1. Motivation (研究动机)
核心问题:实时交互式视频生成
当前实时交互式视频生成的主流方案是将预训练的双向(bidirectional)扩散模型蒸馏为少步自回归(AR)模型。这个过程面临一个根本性挑战——架构鸿沟(architectural gap):
- 双向模型: 使用 full attention,每帧都能看到所有其他帧(过去+未来)
- AR模型: 使用 causal attention,每帧只能看到过去的帧
现有方法(如 Self Forcing)采用两阶段蒸馏:
- ODE 蒸馏初始化: 从双向教师蒸馏到 AR 学生
- DMD 阶段: 进一步提升质量
核心发现:Self Forcing 的 ODE 蒸馏在理论上存在根本缺陷
Figure 3 解读: 这是本文最核心的图。展示了三种 ODE 蒸馏设置下的 injectivity 性质:
- (a) 双向教师 → 双向学生: 标准 DMD 设置。Injectivity 在视频级别成立——不同的 noisy video 映射到不同的 clean video。学生能准确学到 flow map。
- (b) AR 教师 → AR 学生 (Ours): Injectivity 在帧级别成立——每个 noisy frame 映射到唯一的 clean frame。这正是 AR 学生所需要的。
- (c) 双向教师 → AR 学生 (Self Forcing): Injectivity 被违反!同一个 noisy frame 可能映射到多个不同的 clean frame(因为双向模型的 flow map 依赖其他帧的噪声),导致学生学到的是条件均值 E[x₀|xᵢₜ],产生模糊输出。
2. Idea (核心思想)
2.1 核心思路:让蒸馏数据与学生架构一致
关键思想: 使用 Stage 1 得到的 AR 扩散模型(而非双向模型)作为 ODE 蒸馏的教师。
直观上,方法的目标不是继续让双向 teacher 去“迁就” causal student,而是直接让 teacher 先具备 causal 结构,再做 ODE / DMD 蒸馏。这样可以把架构鸿沟前移到更合适的阶段处理。
2.2 为什么这样能解决问题?
定义 Frame-level injectivity (Definition 3.1): 对于映射 φ^{AR}: (xᵢₜ, t) → xᵢ₀,如果对任意两个 noisy video:
即 AR 教师的 PF-ODE flow map 天然满足帧级单射性——因为 AR 模型只依赖过去帧,第 i 帧的去噪结果完全由 (xᵢₜ, x₀^{<i}) 决定,不受未来帧影响。
2.3 训练目标的直观含义
其中 (xᵢₜ, xᵢ₀) 来自 AR 教师的 PF-ODE 轨迹,x_{gt}^{<i} 是真实干净前缀帧。这个目标对应的是“把 noisy frame 在正确的 causal 条件下还原为 clean frame”。
2.4 三阶段流水线概览
Causal Forcing 采用三阶段方案:
Stage 1: AR Diffusion Training (Teacher Forcing)
→ 获得 AR 扩散教师模型
Stage 2: Causal ODE Distillation
→ 用 AR 教师生成配对数据,蒸馏 AR 学生
Stage 3: Asymmetric DMD
→ 进一步提升为少步生成模型
3. Method (方法)
3.1 Stage 1: 自回归扩散训练——为什么选 Teacher Forcing 而非 Diffusion Forcing?
Figure 4 解读: 对比 Diffusion Forcing (DF) 和 Teacher Forcing (TF) 训练 AR 扩散模型的效果。DF 训练时前缀帧是含噪的(x_t^{<i}),推理时前缀是干净的(x_0^{<i}),存在训练-推理分布不匹配,导致视频崩溃(上排)。TF 训练时前缀帧就是干净帧,与推理一致,生成质量更高(下排)。
理论支撑 (Proposition 3.4):
即 Diffusion Forcing 学到的条件分布与真实条件分布之间有不可消除的 KL 散度。
3.2 Stage 2: Causal ODE 蒸馏——核心创新
关键思想: 使用 Stage 1 得到的 AR 扩散模型(而非双向模型)作为 ODE 蒸馏的教师。
为什么这解决了问题?
定义 Frame-level injectivity (Definition 3.1): 对于映射 φ^{AR}: (xᵢₜ, t) → xᵢ₀,如果对任意两个 noisy video:
即 AR 教师的 PF-ODE flow map 天然满足帧级单射性——因为 AR 模型只依赖过去帧,第 i 帧的去噪结果完全由 (xᵢₜ, x₀^{<i}) 决定,不受未来帧影响。
训练目标:
其中 (xᵢₜ, xᵢ₀) 来自 AR 教师的 PF-ODE 轨迹,x_{gt}^{<i} 是真实干净前缀帧。
3.3 Stage 3: Asymmetric DMD
使用 causal ODE 初始化后的模型进行标准的非对称 DMD 训练,使用 VidProM 数据集,训练 750 步至收敛。
Figure 5 解读: Self Forcing 的 ODE 初始化 + DMD 仍然产生较弱的动态效果和伪影(上排),而 Causal Forcing 的 ODE 初始化 + DMD 产生了更强的动态效果和更高的视觉保真度(下排)。这证明 causal ODE 蒸馏确实为 DMD 提供了正确的初始化。
3.4 扩展:Causal Consistency Distillation
本文还将思想扩展到 Consistency Distillation (CD),用 AR 教师替代双向教师:
3.5 Python 伪代码
import torch
import torch.nn.functional as F
# ============================================================
# Stage 1: AR Diffusion Training via Teacher Forcing
# ============================================================
def train_ar_diffusion_teacher_forcing(model, dataloader, num_steps=2000):
"""
训练AR扩散模型,使用Teacher Forcing策略
- 训练时前缀帧为干净帧(与推理一致)
- 使用causal attention mask
"""
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for step in range(num_steps):
video = next(dataloader) # [B, N, C, H, W], N=81 frames
# 随机选择要训练的帧索引
i = torch.randint(1, N, (B,)) # 第i帧
# 前缀帧: 干净的 (Teacher Forcing的关键)
prefix_clean = video[:, :i] # x_0^{<i}, 干净帧作为条件
target_clean = video[:, i] # x_0^i, 目标帧
# 对目标帧加噪
t = torch.rand(B) * T # 随机时间步
noise = torch.randn_like(target_clean)
x_t_i = alpha_t * target_clean + sigma_t * noise # noisy target frame
# 模型预测 (causal attention: 只看过去帧)
# 输入: [prefix_clean | x_t_i], 带causal mask
v_pred = model(x_t_i, prefix_clean, t) # velocity prediction
# Flow matching loss
v_target = target_clean - noise # 或其他参数化
loss = F.mse_loss(v_pred, v_target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return model # AR diffusion teacher
# ============================================================
# Stage 2: Causal ODE Distillation
# ============================================================
def generate_causal_ode_pairs(ar_teacher, dataloader, timesteps_S):
"""
从AR教师模型采样PF-ODE轨迹,生成配对训练数据
关键: 使用AR教师(而非双向教师)保证frame-level injectivity
"""
paired_data = []
for video_clean in dataloader:
for i in range(1, N):
prefix_clean = video_clean[:, :i] # x_{gt}^{<i}
# 从纯噪声开始,沿AR教师的PF-ODE采样
x_T_i = torch.randn(B, C, H, W) # 初始噪声
# 存储轨迹上多个时间步的中间态
trajectory = {}
x_t = x_T_i
for t in reversed(sorted(timesteps_S | {0})):
trajectory[t] = x_t.clone()
if t > 0:
# ODE step: dx = v_teacher(x_t, prefix_clean, t) dt
x_t = ode_step(ar_teacher, x_t, prefix_clean, t)
x_0_i = trajectory[0] # 教师生成的clean frame
# 配对: {(x_t^i, t) -> x_0^i} for all t in S
for t in timesteps_S:
paired_data.append((trajectory[t], prefix_clean, t, x_0_i))
return paired_data
def train_causal_ode_student(student, paired_data, num_steps=1000):
"""
Causal ODE蒸馏: 学生从noisy frame直接回归clean frame
因为AR教师满足frame-level injectivity,学生能准确学到flow map
"""
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
for step in range(num_steps):
x_t_i, prefix_clean, t, x_0_i = sample_batch(paired_data)
# 学生直接预测clean frame
x_0_pred = student(x_t_i, prefix_clean, t)
# MSE回归目标 (Eq. 8)
loss = F.mse_loss(x_0_pred, x_0_i)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return student
# ============================================================
# Stage 3: Asymmetric DMD
# ============================================================
def train_asymmetric_dmd(student, bi_teacher, real_scorer, fake_scorer,
dataloader, num_steps=750):
"""
非对称DMD: 使用双向基座模型作为score teacher
student已经由causal ODE蒸馏初始化(关键!)
"""
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-6)
for step in range(num_steps):
prompt = next(dataloader)
# 学生生成fake video (few-step AR generation)
fake_video = student.generate(prompt, num_steps=1) # 1-step generation
# 对fake video加噪
t = torch.rand(B) * T
noise = torch.randn_like(fake_video)
x_t = alpha_t * fake_video + sigma_t * noise
# Score差异驱动更新 (Eq. 2)
s_real = real_scorer(x_t, t) # 冻结的真实数据score
s_fake = fake_scorer(x_t, t) # 在线更新的fake数据score
# DMD gradient
grad = (s_real - s_fake) * (d_fake_video / d_theta)
loss = (grad * fake_video).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
return student
# ============================================================
# Inference: Real-time Streaming Generation
# ============================================================
def realtime_inference(model, prompt, num_frames=81, chunk_size=3):
"""
实时流式推理: 每次生成一个chunk(3帧)
单张H100/4090即可实时运行
"""
frames = []
for chunk_start in range(0, num_frames, chunk_size):
prefix = torch.stack(frames) if frames else None
# 从噪声开始, 1步生成当前chunk
noise = torch.randn(1, chunk_size, C, H, W)
chunk_frames = model(noise, prefix, prompt, t=1.0) # 单步去噪
frames.extend([chunk_frames[:, j] for j in range(chunk_size)])
# 可以立即显示给用户 (streaming)
yield chunk_frames4. Experimental Setup (实验设置)
4.1 实现细节
- 基座模型: Wan2.1-T2V-1.3B,生成 81 帧 832x480 视频
- AR 训练: 在 3K 合成数据集 D_Bi 上训练 2K 步
- Causal ODE: 采样 3K causal ODE 轨迹 D_Causal,训练 1K 步
- DMD: 使用 VidProM 数据集,训练 750 步
- 推理: chunk-wise(每 chunk 3 帧),单步生成,单卡 H100 实时
- 吞吐量: 17.0 FPS,延迟 0.69 秒
4.2 代码映射 (GitHub Repo)
仓库地址: https://github.com/thu-ml/Causal-Forcing
| 论文概念 | 代码位置 | 说明 |
|---|---|---|
| 整体训练入口 | train.py | 支持三阶段训练的统一入口 |
| AR Diffusion (Stage 1) | trainer/ + configs/ | Teacher Forcing 训练 AR 扩散模型 |
| Causal ODE 数据生成 | get_causal_ode_data_framewise.py, get_causal_ode_data_chunkwise.py | 从 AR 教师采样 PF-ODE 轨迹生成配对数据 |
| ODE 蒸馏 + DMD | trainer/ | Causal ODE 初始化 + 非对称 DMD 训练 |
| 模型架构 | model/ | 基于 Wan2.1 的 DiT 架构,支持 causal attention |
| 推理流水线 | inference.py, pipeline/ | chunk-wise 流式生成 |
| Causal CD 扩展 | configs/ 中 CD 相关配置 | Causal Consistency Distillation |
| 基座模型集成 | wan/ | Wan2.1-T2V-1.3B 的适配层 |
| 长视频生成 | long_video/ | 分钟级视频生成扩展 |
5. Experimental Results (实验结果)
5.1 定量比较
| 模型 | 吞吐量↑ | 延迟↓ | VBench Total↑ | Dynamic Degree↑ | VisionReward↑ | Instruct Follow↑ | 用户评分↓ |
|---|---|---|---|---|---|---|---|
| Wan2.1-1.3B | 0.78 | 103 | 83.37 | 61 | 5.275 | 42 | 2.29 |
| LTX-1.9B | 8.98 | 13.5 | 79.83 | 46 | -6.218 | -38 | 6.40 |
| NOVA | 0.88 | 4.1 | 80.31 | 46 | -7.381 | -16 | 8.41 |
| Pyramid Flow | 6.70 | 2.5 | 80.75 | 16 | 4.055 | -2 | 6.11 |
| CausVid | 17.0 | 0.69 | 81.33 | 62 | 5.741 | 12 | 4.27 |
| Self Forcing | 17.0 | 0.69 | 83.74 | 57 | 5.820 | 48 | 2.87 |
| Causal Forcing (Ours) | 17.0 | 0.69 | 84.04 | 68 | 6.326 | 56 | 1.64 |
核心结论:
- 在相同训练预算和推理延迟下,全面超越 Self Forcing
- Dynamic Degree 提升 19.3%,VisionReward 提升 8.7%,Instruction Following 提升 16.7%
- 比双向模型 Wan2.1 吞吐量高 2079%,质量相当甚至超越
5.2 消融研究
Figure 2 解读: 验证 DMD 无法弥合架构鸿沟。即使用标准 DMD 初始化 AR 学生(消除了 sampling-step gap,只保留 architectural gap),性能仍显著低于标准 DMD 蒸馏的双向学生。这证明架构鸿沟必须在 ODE 蒸馏阶段解决,而非依赖后续的 DMD。
消融结论 (Table 2):
- AR 训练策略: Teacher Forcing 全面优于 Diffusion Forcing(VisionReward 提升 111.2%)
- ODE 初始化: Causal ODE 初始化大幅优于 Self Forcing 的 ODE 初始化(chunk-wise: VisionReward +90.0%, Dynamic Degree +183.3%)
- Consistency Distillation: Causal CD 优于 Asymmetric CD
5.3 定性比较

Figure 6 解读: 与 Wan2.1、CausVid、Self Forcing 的定性对比。Causal Forcing 在动态程度和视觉质量上显著优于现有蒸馏方法,同时匹配甚至超越双向扩散模型 Wan2.1。
5.4 Figure 解读汇总
Figure 1 解读: 展示现有方法的局限性。从同一个双向基座模型蒸馏时,SOTA 的 AR 蒸馏方法 Self Forcing 仍然显著落后于标准 DMD(蒸馏出双向学生)。
Figure 7 解读: 这是 injectivity 示意图的第一张补充子图,进一步可视化帧级单射性。(a) 双向 flow map 在帧级不满足单射性——同一 noisy frame 在不同 video context 下映射到不同 clean frame。
Figure 8 解读: 这是 injectivity 示意图的第二张补充子图,进一步可视化帧级单射性。(c) AR flow map 天然满足帧级单射性。
Figure 9 解读: Causal ODE with Bidirectional Init。即使用双向模型初始化但采用 causal ODE 数据也能获得不错效果,验证了关键在于 ODE 数据来自 AR 教师。