CausVid: From Slow Bidirectional to Fast Autoregressive Video Diffusion Models
Authors: Tianwei Yin*, Qiang Zhang*, Richard Zhang, William T. Freeman, Fredo Durand, Eli Shechtman, Xun Huang Affiliations: MIT, Adobe Venue: CVPR 2025
1. Motivation
1.1 问题背景
当前主流的视频扩散模型(如 CogVideoX、MovieGen)都采用 Diffusion Transformer (DiT) 架构,使用 双向注意力 (bidirectional attention) 跨所有帧进行全序列去噪。这带来两个核心瓶颈:
- 高延迟:生成单帧需要处理整个序列(包括未来帧),用户必须等待全部帧生成完毕才能看到结果。例如生成一个 128 帧视频需要 219 秒。
- 无法流式生成:双向注意力使得当前帧依赖未来帧,不支持 streaming 和交互式应用(如实时 video-to-video、动态 prompting)。
- 计算量随帧数二次增长:attention 的计算和内存开销随帧数平方增长,使长视频生成极其昂贵。
1.2 现有方案的局限
- 自回归视频扩散模型:虽然能流式生成,但 (a) 现有 autoregressive 扩散模型质量远不及双向模型,(b) 误差累积导致长视频质量快速退化,(c) 每步仍需 50 步去噪,速度依然不够。
- 蒸馏方法:现有视频蒸馏(如 consistency distillation、adversarial distillation)大多 (a) 只处理短片段(<2 秒),(b) 将 non-causal teacher 蒸馏为 non-causal student,不改变流式生成的根本限制。
1.3 核心矛盾
如何同时获得 双向模型的高质量 和 自回归模型的低延迟 + 流式生成能力?直接将双向模型转为因果模型并蒸馏效果很差——因果 teacher 本身就弱于双向 teacher,误差会传递给 student。
Figure 1 解读:顶部是传统双向扩散模型,延迟 219 秒生成 128 帧视频,用户需等待全部完成。底部是 CausVid 的因果 student 模型,初始延迟仅 1.3 秒即可开始输出第一批帧,之后以 9.4 FPS 流式持续生成。通过非对称蒸馏(Asymmetric Distillation with DMD),将双向 teacher 的知识迁移到因果 student,同时将 50 步去噪压缩为 4 步。
2. Idea
2.1 核心思路:非对称蒸馏
CausVid 的关键洞察是:不要从因果 teacher 蒸馏到因果 student,而是用双向 teacher 监督因果 student。这种”非对称”(asymmetric)蒸馏策略有两大优势:
- 双向 teacher 质量更高,能为因果 student 提供更强的监督信号
- DMD 损失在分布层面做匹配(而非逐样本),天然允许 teacher 和 student 有不同的架构(双向 vs 因果)
2.2 技术路线
整体方案分三步:
- 架构转换:将预训练双向 DiT 改为 block-wise causal attention(chunk 内双向,chunk 间因果),并用去噪损失微调
- ODE 初始化:用双向 teacher 生成 ODE 轨迹数据集,预训练因果 student 做回归,稳定后续蒸馏
- DMD 蒸馏:用双向 teacher 的 score function 作为 ,在线训练因果 student 的 score function ,通过 reverse KL 散度梯度更新 student
2.3 推理加速:KV Caching
因果架构天然支持 KV caching——已生成 chunk 的 key/value 可以缓存复用,新 chunk 只需计算自身的 attention,实现流式生成。
Figure 2 解读:CausVid 支持的多种视频生成任务。(1) Text-to-Video:纯文本驱动生成高质量视频;(2) Image-to-Video:给定首帧图像 + 文本,自回归扩展为视频(zero-shot,无需额外训练);(3) Video-to-Video:将游戏引擎渲染的简单输入实时转换为逼真视频(流式处理);(4) Dynamic Prompting:在视频生成过程中随时更换文本 prompt,实现剧情动态变化。所有任务均受益于低延迟的因果生成架构。
3. Method
3.1 自回归扩散 Transformer 架构
3D VAE 编码:视频首先通过 3D VAE 压缩到潜空间。VAE encoder 将每 16 帧视频压缩为 5 个 latent frames(一个 chunk),空间和时间维度同时压缩。
Block-wise Causal Attention:在标准 DiT 的 self-attention 层上施加 block-wise causal mask:
其中 是帧索引, 是 chunk size。chunk 内各帧双向可见(捕获局部时序依赖),chunk 间只能看到过去(确保因果性)。这样既保持了因果性用于流式推理,又保留了局部双向注意力提升质量。
Per-chunk 独立噪声时间步:遵循 Diffusion Forcing 的设计,每个 chunk 有独立采样的噪声时间步 ,支持不同 chunk 处于不同去噪阶段。
Figure 5 解读:自回归扩散 Transformer 架构示意图。三个 noisy frame chunk 输入 AR-DiT,注意力遵循 block-wise causal mask——当前 chunk 的 token 可以 attend 到自身和所有之前 chunk 的 token,但不能 attend 到未来 chunk。每个 chunk 有独立的噪声 level,模型在去噪过程中逐 chunk 自回归生成。
3.2 非对称蒸馏(Bidirectional Teacher → Causal Student)
为什么需要非对称? 直接微调双向 DiT 为因果模型后做蒸馏(causal teacher → causal student)效果较差:因果 teacher 质量本身不如双向模型,且误差会从 teacher 传递到 student 并在自回归推理中累积放大。
DMD 损失函数:
其中:
- :因果 student generator(4 步去噪)
- :双向 teacher 的 score function(冻结)
- :student 输出分布的 score function(在线训练)
- :前向扩散加噪过程
关键点: 使用双向注意力, 使用因果注意力,DMD 在分布层面做匹配,允许这种架构不对称。
训练流程(Algorithm 1):
- 采样视频,分为 个 chunk
- 每个 chunk 独立采样噪声时间步,加噪
- Student 用因果 mask 预测 clean frames
- 对 加噪到随机时间步
- 用 DMD loss( vs )更新 student
- 用 student 输出训练 的 denoising loss
Figure 6 解读:训练流程两阶段。上方 Student Initialization:用双向 teacher 运行 ODE solver 生成 noise → clean video 的轨迹对,因果 student 在这些 ODE 对上做 MSE 回归预训练,学习基本的去噪映射。下方 Asymmetric Distillation with DMD:从数据集采样视频加噪后,因果 student 生成 clean 预测,然后重新加噪;双向 teacher 的 score function 提供”真实分布”方向的梯度,在线 score function 提供”生成分布”方向的梯度,二者之差驱动 student 向真实分布对齐。
3.3 Student 初始化(ODE Regression)
直接用 DMD loss 训练因果 student 不稳定(架构差异大)。解决方案:
- 用双向 teacher 的 ODE solver 生成轨迹 (从纯噪声 到干净数据 )
- 选取与 student 推理时间步匹配的子集
- Student 在这些 ODE 对上做回归预训练:
这一步只需约 1000 个 ODE 对,训练 3000 步即可收敛,为后续 DMD 蒸馏提供稳定的起点。
3.4 KV Caching 推理
推理时采用 KV caching 实现高效流式生成(Algorithm 2):
- 初始化 KV cache 为空
- 对每个 chunk :
- 从纯噪声 开始
- 执行 4 步迭代去噪,每步利用 cache 中之前 chunk 的 KV pairs
- 去噪完成后,对 clean 结果 做一次前向传播计算 KV pairs 并追加到 cache
- 返回所有 chunk 的 clean 结果
由于使用 KV caching,每步去噪只需计算当前 chunk 的 attention(不需要重新计算所有历史帧),实现了 9.4 FPS 的流式生成速度。值得注意的是,在推理阶段可以使用标准的 FlashAttention 加速(因为因果性通过 KV cache 隐式保证,不需要显式 causal mask)。
4. Experimental Setup (实验设置)
4.1 实验设置
- Teacher 模型:双向 DiT(架构类似 CogVideoX),50 步去噪,训练在 352x640、12 FPS、10 秒视频上
- Student 模型:同一架构 + block-wise causal attention mask,4 步去噪(时间步 [999, 748, 502, 247])
- 训练数据:约 400K 单镜头视频(含图像和视频混合数据集),经安全和美学过滤
- 训练流程:ODE 初始化 3000 步 → DMD 蒸馏 6000 步,AdamW 优化器,64 张 H100 GPU 约 2 天
- 评估指标:VBench(16 项指标:temporal quality, frame quality, text alignment)
5. Experimental Results (实验结果)
5.1 Text-to-Short-Video(5-10 秒)
| Method | Length(s) | Temporal Quality | Frame Quality | Text Alignment |
|---|---|---|---|---|
| CogVideoX-5B | 6 | 89.9 | 59.8 | 29.1 |
| OpenSORA | 8 | 88.4 | 52.0 | 28.4 |
| Pyramid Flow | 10 | 89.6 | 55.9 | 27.1 |
| MovieGen | 10 | 91.5 | 61.1 | 28.8 |
| CausVid (Ours) | 10 | 94.7 | 64.4 | 30.1 |
CausVid 在三个核心维度上全面超越所有基线,temporal quality 达到 94.7(运动一致性和动态质量最优)。
5.2 Text-to-Long-Video(~30 秒)
| Method | Temporal Quality | Frame Quality | Text Alignment |
|---|---|---|---|
| Gen-L-Video | 86.7 | 52.3 | 28.7 |
| FreeNoise | 86.2 | 54.8 | 28.7 |
| StreamingT2V | 89.2 | 46.1 | 27.2 |
| FIFO-Diffusion | 93.1 | 57.9 | 29.9 |
| Pyramid Flow | 89.0 | 48.3 | 24.4 |
| CausVid (Ours) | 94.9 | 63.4 | 28.9 |
长视频生成中 CausVid 同样以大幅优势领先,尤其在 temporal quality 和 frame quality 上。VBench-Long 总分达到 84.27,排名第一。
5.3 延迟与吞吐量
| Method | Latency (s) | Throughput (FPS) |
|---|---|---|
| CogVideoX-5B | 208.6 | 0.6 |
| Pyramid Flow | 6.7 | 2.5 |
| Bidirectional Teacher | 219.2 | 0.6 |
| CausVid (Ours) | 1.3 | 9.4 |
CausVid 实现 160x 延迟降低 和 16x 吞吐量提升(相比同架构双向 teacher)。初始延迟仅 1.3 秒即可开始流式输出。
5.4 消融实验
Figure 8 解读:30 秒视频生成过程中各方法的帧质量随时间变化曲线。橙色是因果 teacher(直接微调双向模型为因果),质量快速退化(误差累积)。绿色是”CausVid with Causal Teacher”(因果 teacher 蒸馏因果 student),退化稍慢但仍显著。蓝色是完整 CausVid(双向 teacher 蒸馏因果 student),质量在 30 秒内保持稳定,几乎无退化。FIFO-Diffusion(紫色)也能保持质量但延迟高得多。这证明了非对称蒸馏是抑制误差累积的关键。
关键消融结论(Table 4):
| ODE Init. | Teacher | Temporal Quality | Frame Quality | Text Alignment |
|---|---|---|---|---|
| ✗ | Bidirectional | 93.4 | 60.6 | 29.4 |
| ✓ | None | 92.9 | 48.1 | 25.3 |
| ✓ | Causal | 91.9 | 61.7 | 28.2 |
| ✓ | Bidirectional | 94.7 | 64.4 | 30.1 |
- 双向 teacher > 因果 teacher > 无 teacher:双向 teacher 的监督最有效
- ODE 初始化有帮助:提升训练稳定性和最终质量
- 组合效果最优:ODE 初始化 + 双向 teacher DMD 蒸馏
5.5 应用
Streaming Video-to-Video Translation:
| Method | Temporal Quality | Frame Quality | Text Alignment |
|---|---|---|---|
| StreamV2V | 92.5 | 59.3 | 26.9 |
| CausVid (Ours) | 93.2 | 61.7 | 27.7 |
通过 SDEdit 机制(对输入视频 chunk 加噪到 再一步去噪),实现流式视频风格转换。
Image-to-Video(Zero-shot):将输入图像复制为第一个 chunk 的所有帧,后续 chunk 自回归生成。在 VBench-I2V 上 temporal quality 达 92.0,frame quality 65.0,超越 CogVideoX-5B 和 Pyramid Flow。
Dynamic Prompting:生成过程中可随时更换文本 prompt,实现场景和动作的动态切换。
Figure 7 解读:用户偏好研究结果。CausVid 在与 MovieGen(60.9% vs 39.1%)、CogVideoX(69.0% vs 31.0%)、Pyramid Flow(61.4% vs 17.6%,剩余持平)的对比中均获得多数偏好。与自己的双向 teacher 对比(51.7% vs 48.3%)几乎持平,说明蒸馏后的 4 步因果 student 质量逼近 50 步双向 teacher。
6. Pseudocode
5.1 Block-wise Causal Attention Mask
import torch
def create_blockwise_causal_mask(num_frames: int, chunk_size: int) -> torch.Tensor:
"""
创建 block-wise causal attention mask.
chunk 内双向可见,chunk 间只能看到过去。
Args:
num_frames: 总帧数 N
chunk_size: 每个 chunk 的帧数 k
Returns:
mask: (N, N) 二值矩阵, 1=可见, 0=遮挡
"""
mask = torch.zeros(num_frames, num_frames)
for i in range(num_frames):
for j in range(num_frames):
if j // chunk_size <= i // chunk_size:
mask[i, j] = 1.0
return mask # (N, N)5.2 Asymmetric DMD Training
import torch
import torch.nn.functional as F
def asymmetric_dmd_training_step(
student_G: "CausalDiT", # 因果 student generator (4-step)
s_data: "BidirectionalDiT", # 双向 teacher score function (frozen)
s_gen: "CausalDiT", # generator 分布的 score function (online)
video_batch: torch.Tensor, # (B, N, C, H, W) clean video
chunk_size: int,
noise_schedule: "NoiseSchedule",
timesteps_student: list, # e.g., [999, 748, 502, 247]
):
"""
一步非对称 DMD 蒸馏训练。
核心:用双向 teacher 的 s_data 提供"真实分布"梯度,
用在线训练的 s_gen 提供"生成分布"梯度,
二者之差驱动因果 student 向真实分布对齐。
"""
B, N, C, H, W = video_batch.shape
L = N // chunk_size # chunk 数量
# --- Step 1: Student 生成 ---
# 每个 chunk 独立采样噪声时间步
t_per_chunk = torch.randint(0, len(timesteps_student), (B, L)) # (B, L)
noise = torch.randn_like(video_batch) # (B, N, C, H, W)
# 对每个 chunk 加对应时间步的噪声
noisy_video = add_noise_per_chunk(video_batch, noise, t_per_chunk,
chunk_size, noise_schedule)
# Student 用 block-wise causal mask 预测 clean video
x0_pred = student_G(noisy_video, t_per_chunk) # (B, N, C, H, W)
# --- Step 2: DMD Loss 计算 ---
# 对预测结果加随机噪声(用于 score function 评估)
t_dmd = torch.randint(0, 1000, (B,)) # 随机采样单一时间步
noise_dmd = torch.randn_like(x0_pred)
x_t = noise_schedule.add_noise(x0_pred, noise_dmd, t_dmd) # (B, N, C, H, W)
# 双向 teacher score(frozen):真实数据分布的梯度方向
with torch.no_grad():
score_data = s_data(x_t, t_dmd) # 双向 attention, 全帧可见
# Generator score(online):当前生成分布的梯度方向
score_gen = s_gen(x_t, t_dmd) # 因果 attention
# DMD gradient: score_data - score_gen 驱动 student 更新
# 实际实现中通过 score difference 作为伪梯度反传
dmd_loss = compute_dmd_gradient(x0_pred, score_data, score_gen, t_dmd)
# --- Step 3: 更新 s_gen 的 denoising loss ---
# 用 student 的输出作为"数据"训练 s_gen
with torch.no_grad():
x0_detached = x0_pred.detach()
noise_new = torch.randn_like(x0_detached)
t_new = torch.randint(0, 1000, (B,))
x_t_new = noise_schedule.add_noise(x0_detached, noise_new, t_new)
eps_pred = s_gen(x_t_new, t_new)
s_gen_loss = F.mse_loss(eps_pred, noise_new) # 标准 denoising loss
return dmd_loss, s_gen_loss5.3 KV Caching Inference
import torch
def autoregressive_inference_with_kv_cache(
student_G: "CausalDiT",
prompt_embedding: torch.Tensor,
num_chunks: int,
chunk_size: int,
timesteps: list = [999, 748, 502, 247], # 4 步去噪
):
"""
Algorithm 2: 使用 KV Caching 的流式自回归推理。
核心:每个 chunk 去噪完成后,计算其 KV pairs 并缓存;
后续 chunk 去噪时复用历史 KV,避免重复计算。
"""
kv_cache = None # 初始化为空
all_chunks = []
for i in range(num_chunks):
# 从纯噪声开始
x_i = torch.randn(1, chunk_size, C, H, W) # 当前 chunk
# 4 步迭代去噪
for j in range(len(timesteps) - 1, -1, -1):
t_j = timesteps[j]
# 使用 KV cache 中之前 chunk 的信息
x0_pred = student_G.forward_with_cache(x_i, t_j, kv_cache)
if j > 0:
# DDIM-style 更新到下一个时间步
t_prev = timesteps[j - 1]
x_i = ddim_step(x0_pred, x_i, t_j, t_prev)
else:
x_i = x0_pred # 最后一步直接输出
# 去噪完成,更新 KV cache
# 对 clean chunk 做一次 t=0 的前向传播,获取 KV pairs
with torch.no_grad():
new_kv = student_G.compute_kv(x_i, t=0)
kv_cache = append_to_cache(kv_cache, new_kv)
all_chunks.append(x_i)
# 此时可以立即输出/显示当前 chunk(流式)
return torch.cat(all_chunks, dim=1) # (1, N, C, H, W)7. Code Mapping
GitHub 仓库: tianweiy/CausVid
| 论文概念 | 代码位置 | 说明 |
|---|---|---|
| Block-wise Causal DiT | causvid/models/wan/causal_model.py | 因果注意力 mask 实现,chunk 内双向、chunk 间因果 |
| Bidirectional DiT Wrapper | causvid/models/wan/wan_wrapper.py | 双向 teacher 模型封装 |
| DMD Loss | causvid/dmd.py + causvid/loss.py | Distribution Matching Distillation 损失计算,包含 score function 差值梯度 |
| Asymmetric Distillation Training | causvid/train_distillation.py | 主训练循环:因果 student + 双向 teacher 的非对称蒸馏 |
| ODE Trajectory Generation | causvid/models/wan/generate_ode_pairs.py | 用双向 teacher 的 ODE solver 生成初始化数据 |
| ODE Regression (Student Init) | causvid/train_ode.py + causvid/ode_regression.py | Student 在 ODE 对上的预训练回归 |
| ODE Data Pipeline | causvid/ode_data/create_lmdb_iterative.py | ODE 轨迹数据的 LMDB 存储和读取 |
| Bidirectional Trajectory Pipeline | causvid/bidirectional_trajectory_pipeline.py | 双向模型的完整 ODE 采样流程 |
| Noise Schedule (Flow Matching) | causvid/models/wan/flow_match.py + causvid/scheduler.py | 噪声调度和 flow matching 相关实现 |
| Model Interface | causvid/models/model_interface.py | 统一模型接口(支持 SDXL 和 Wan 架构) |
| Data Loading | causvid/data.py | 视频数据加载和预处理 |
| Utilities | causvid/util.py | 通用工具函数 |
| AR Inference (Short Video) | minimal_inference/autoregressive_inference.py | 5 秒自回归推理脚本 |
| AR Inference (Long Video) | minimal_inference/longvideo_autoregressive_inference.py | 长视频流式推理(多 chunk rollout) |
| Bidirectional Inference | minimal_inference/bidirectional_inference.py | 双向模型推理(对比基线) |
| KV Caching Inference | causvid/models/wan/causal_inference.py | KV cache 管理和因果推理逻辑 |
| Config: Causal DMD | configs/wan_causal_dmd.yaml | 因果 DMD 蒸馏的超参数配置 |
| Config: Causal ODE | configs/wan_causal_ode.yaml | ODE 预训练的超参数配置 |
| Config: Bidirectional DMD | configs/wan_bidirectional_dmd.yaml | 双向 DMD 蒸馏配置(消融实验) |
训练 Pipeline 总览
1. 数据准备 (distillation_data/)
└→ 视频收集 + VAE latent 预计算
2. ODE 数据生成 (causvid/models/wan/generate_ode_pairs.py)
└→ 双向 teacher 运行 ODE solver 生成轨迹
3. ODE 回归预训练 (causvid/train_ode.py, configs/wan_causal_ode.yaml)
└→ 因果 student 在 ODE 对上做 MSE 回归, 3000 步
4. DMD 蒸馏 (causvid/train_distillation.py, configs/wan_causal_dmd.yaml)
└→ 非对称蒸馏: 双向 teacher score + 在线 gen score, 6000 步
5. 推理 (minimal_inference/autoregressive_inference.py)
└→ 4 步去噪 + KV caching 流式生成