CausVid: From Slow Bidirectional to Fast Autoregressive Video Diffusion Models

Authors: Tianwei Yin*, Qiang Zhang*, Richard Zhang, William T. Freeman, Fredo Durand, Eli Shechtman, Xun Huang Affiliations: MIT, Adobe Venue: CVPR 2025

1. Motivation

1.1 问题背景

当前主流的视频扩散模型(如 CogVideoX、MovieGen)都采用 Diffusion Transformer (DiT) 架构,使用 双向注意力 (bidirectional attention) 跨所有帧进行全序列去噪。这带来两个核心瓶颈:

  1. 高延迟:生成单帧需要处理整个序列(包括未来帧),用户必须等待全部帧生成完毕才能看到结果。例如生成一个 128 帧视频需要 219 秒。
  2. 无法流式生成:双向注意力使得当前帧依赖未来帧,不支持 streaming 和交互式应用(如实时 video-to-video、动态 prompting)。
  3. 计算量随帧数二次增长:attention 的计算和内存开销随帧数平方增长,使长视频生成极其昂贵。

1.2 现有方案的局限

  • 自回归视频扩散模型:虽然能流式生成,但 (a) 现有 autoregressive 扩散模型质量远不及双向模型,(b) 误差累积导致长视频质量快速退化,(c) 每步仍需 50 步去噪,速度依然不够。
  • 蒸馏方法:现有视频蒸馏(如 consistency distillation、adversarial distillation)大多 (a) 只处理短片段(<2 秒),(b) 将 non-causal teacher 蒸馏为 non-causal student,不改变流式生成的根本限制。

1.3 核心矛盾

如何同时获得 双向模型的高质量自回归模型的低延迟 + 流式生成能力?直接将双向模型转为因果模型并蒸馏效果很差——因果 teacher 本身就弱于双向 teacher,误差会传递给 student。

Figure 1 解读:顶部是传统双向扩散模型,延迟 219 秒生成 128 帧视频,用户需等待全部完成。底部是 CausVid 的因果 student 模型,初始延迟仅 1.3 秒即可开始输出第一批帧,之后以 9.4 FPS 流式持续生成。通过非对称蒸馏(Asymmetric Distillation with DMD),将双向 teacher 的知识迁移到因果 student,同时将 50 步去噪压缩为 4 步。


2. Idea

2.1 核心思路:非对称蒸馏

CausVid 的关键洞察是:不要从因果 teacher 蒸馏到因果 student,而是用双向 teacher 监督因果 student。这种”非对称”(asymmetric)蒸馏策略有两大优势:

  1. 双向 teacher 质量更高,能为因果 student 提供更强的监督信号
  2. DMD 损失在分布层面做匹配(而非逐样本),天然允许 teacher 和 student 有不同的架构(双向 vs 因果)

2.2 技术路线

整体方案分三步:

  1. 架构转换:将预训练双向 DiT 改为 block-wise causal attention(chunk 内双向,chunk 间因果),并用去噪损失微调
  2. ODE 初始化:用双向 teacher 生成 ODE 轨迹数据集,预训练因果 student 做回归,稳定后续蒸馏
  3. DMD 蒸馏:用双向 teacher 的 score function 作为 ,在线训练因果 student 的 score function ,通过 reverse KL 散度梯度更新 student

2.3 推理加速:KV Caching

因果架构天然支持 KV caching——已生成 chunk 的 key/value 可以缓存复用,新 chunk 只需计算自身的 attention,实现流式生成。

Figure 2 解读:CausVid 支持的多种视频生成任务。(1) Text-to-Video:纯文本驱动生成高质量视频;(2) Image-to-Video:给定首帧图像 + 文本,自回归扩展为视频(zero-shot,无需额外训练);(3) Video-to-Video:将游戏引擎渲染的简单输入实时转换为逼真视频(流式处理);(4) Dynamic Prompting:在视频生成过程中随时更换文本 prompt,实现剧情动态变化。所有任务均受益于低延迟的因果生成架构。


3. Method

3.1 自回归扩散 Transformer 架构

3D VAE 编码:视频首先通过 3D VAE 压缩到潜空间。VAE encoder 将每 16 帧视频压缩为 5 个 latent frames(一个 chunk),空间和时间维度同时压缩。

Block-wise Causal Attention:在标准 DiT 的 self-attention 层上施加 block-wise causal mask:

其中 是帧索引, 是 chunk size。chunk 内各帧双向可见(捕获局部时序依赖),chunk 间只能看到过去(确保因果性)。这样既保持了因果性用于流式推理,又保留了局部双向注意力提升质量。

Per-chunk 独立噪声时间步:遵循 Diffusion Forcing 的设计,每个 chunk 有独立采样的噪声时间步 ,支持不同 chunk 处于不同去噪阶段。

Figure 5 解读:自回归扩散 Transformer 架构示意图。三个 noisy frame chunk 输入 AR-DiT,注意力遵循 block-wise causal mask——当前 chunk 的 token 可以 attend 到自身和所有之前 chunk 的 token,但不能 attend 到未来 chunk。每个 chunk 有独立的噪声 level,模型在去噪过程中逐 chunk 自回归生成。

3.2 非对称蒸馏(Bidirectional Teacher → Causal Student)

为什么需要非对称? 直接微调双向 DiT 为因果模型后做蒸馏(causal teacher → causal student)效果较差:因果 teacher 质量本身不如双向模型,且误差会从 teacher 传递到 student 并在自回归推理中累积放大。

DMD 损失函数

其中:

  • :因果 student generator(4 步去噪)
  • 双向 teacher 的 score function(冻结)
  • :student 输出分布的 score function(在线训练)
  • :前向扩散加噪过程

关键点: 使用双向注意力, 使用因果注意力,DMD 在分布层面做匹配,允许这种架构不对称。

训练流程(Algorithm 1):

  1. 采样视频,分为 个 chunk
  2. 每个 chunk 独立采样噪声时间步,加噪
  3. Student 用因果 mask 预测 clean frames
  4. 加噪到随机时间步
  5. 用 DMD loss( vs )更新 student
  6. 用 student 输出训练 的 denoising loss

Figure 6 解读:训练流程两阶段。上方 Student Initialization:用双向 teacher 运行 ODE solver 生成 noise → clean video 的轨迹对,因果 student 在这些 ODE 对上做 MSE 回归预训练,学习基本的去噪映射。下方 Asymmetric Distillation with DMD:从数据集采样视频加噪后,因果 student 生成 clean 预测,然后重新加噪;双向 teacher 的 score function 提供”真实分布”方向的梯度,在线 score function 提供”生成分布”方向的梯度,二者之差驱动 student 向真实分布对齐。

3.3 Student 初始化(ODE Regression)

直接用 DMD loss 训练因果 student 不稳定(架构差异大)。解决方案:

  1. 用双向 teacher 的 ODE solver 生成轨迹 (从纯噪声 到干净数据
  2. 选取与 student 推理时间步匹配的子集
  3. Student 在这些 ODE 对上做回归预训练:

这一步只需约 1000 个 ODE 对,训练 3000 步即可收敛,为后续 DMD 蒸馏提供稳定的起点。

3.4 KV Caching 推理

推理时采用 KV caching 实现高效流式生成(Algorithm 2):

  1. 初始化 KV cache 为空
  2. 对每个 chunk
    • 从纯噪声 开始
    • 执行 4 步迭代去噪,每步利用 cache 中之前 chunk 的 KV pairs
    • 去噪完成后,对 clean 结果 做一次前向传播计算 KV pairs 并追加到 cache
  3. 返回所有 chunk 的 clean 结果

由于使用 KV caching,每步去噪只需计算当前 chunk 的 attention(不需要重新计算所有历史帧),实现了 9.4 FPS 的流式生成速度。值得注意的是,在推理阶段可以使用标准的 FlashAttention 加速(因为因果性通过 KV cache 隐式保证,不需要显式 causal mask)。


4. Experimental Setup (实验设置)

4.1 实验设置

  • Teacher 模型:双向 DiT(架构类似 CogVideoX),50 步去噪,训练在 352x640、12 FPS、10 秒视频上
  • Student 模型:同一架构 + block-wise causal attention mask,4 步去噪(时间步 [999, 748, 502, 247])
  • 训练数据:约 400K 单镜头视频(含图像和视频混合数据集),经安全和美学过滤
  • 训练流程:ODE 初始化 3000 步 → DMD 蒸馏 6000 步,AdamW 优化器,64 张 H100 GPU 约 2 天
  • 评估指标:VBench(16 项指标:temporal quality, frame quality, text alignment)

5. Experimental Results (实验结果)

5.1 Text-to-Short-Video(5-10 秒)

MethodLength(s)Temporal QualityFrame QualityText Alignment
CogVideoX-5B689.959.829.1
OpenSORA888.452.028.4
Pyramid Flow1089.655.927.1
MovieGen1091.561.128.8
CausVid (Ours)1094.764.430.1

CausVid 在三个核心维度上全面超越所有基线,temporal quality 达到 94.7(运动一致性和动态质量最优)。

5.2 Text-to-Long-Video(~30 秒)

MethodTemporal QualityFrame QualityText Alignment
Gen-L-Video86.752.328.7
FreeNoise86.254.828.7
StreamingT2V89.246.127.2
FIFO-Diffusion93.157.929.9
Pyramid Flow89.048.324.4
CausVid (Ours)94.963.428.9

长视频生成中 CausVid 同样以大幅优势领先,尤其在 temporal quality 和 frame quality 上。VBench-Long 总分达到 84.27,排名第一。

5.3 延迟与吞吐量

MethodLatency (s)Throughput (FPS)
CogVideoX-5B208.60.6
Pyramid Flow6.72.5
Bidirectional Teacher219.20.6
CausVid (Ours)1.39.4

CausVid 实现 160x 延迟降低16x 吞吐量提升(相比同架构双向 teacher)。初始延迟仅 1.3 秒即可开始流式输出。

5.4 消融实验

Figure 8 解读:30 秒视频生成过程中各方法的帧质量随时间变化曲线。橙色是因果 teacher(直接微调双向模型为因果),质量快速退化(误差累积)。绿色是”CausVid with Causal Teacher”(因果 teacher 蒸馏因果 student),退化稍慢但仍显著。蓝色是完整 CausVid(双向 teacher 蒸馏因果 student),质量在 30 秒内保持稳定,几乎无退化。FIFO-Diffusion(紫色)也能保持质量但延迟高得多。这证明了非对称蒸馏是抑制误差累积的关键。

关键消融结论(Table 4):

ODE Init.TeacherTemporal QualityFrame QualityText Alignment
Bidirectional93.460.629.4
None92.948.125.3
Causal91.961.728.2
Bidirectional94.764.430.1
  • 双向 teacher > 因果 teacher > 无 teacher:双向 teacher 的监督最有效
  • ODE 初始化有帮助:提升训练稳定性和最终质量
  • 组合效果最优:ODE 初始化 + 双向 teacher DMD 蒸馏

5.5 应用

Streaming Video-to-Video Translation

MethodTemporal QualityFrame QualityText Alignment
StreamV2V92.559.326.9
CausVid (Ours)93.261.727.7

通过 SDEdit 机制(对输入视频 chunk 加噪到 再一步去噪),实现流式视频风格转换。

Image-to-Video(Zero-shot):将输入图像复制为第一个 chunk 的所有帧,后续 chunk 自回归生成。在 VBench-I2V 上 temporal quality 达 92.0,frame quality 65.0,超越 CogVideoX-5B 和 Pyramid Flow。

Dynamic Prompting:生成过程中可随时更换文本 prompt,实现场景和动作的动态切换。

Figure 7 解读:用户偏好研究结果。CausVid 在与 MovieGen(60.9% vs 39.1%)、CogVideoX(69.0% vs 31.0%)、Pyramid Flow(61.4% vs 17.6%,剩余持平)的对比中均获得多数偏好。与自己的双向 teacher 对比(51.7% vs 48.3%)几乎持平,说明蒸馏后的 4 步因果 student 质量逼近 50 步双向 teacher。


6. Pseudocode

5.1 Block-wise Causal Attention Mask

import torch
 
def create_blockwise_causal_mask(num_frames: int, chunk_size: int) -> torch.Tensor:
    """
    创建 block-wise causal attention mask.
    chunk 内双向可见,chunk 间只能看到过去。
 
    Args:
        num_frames: 总帧数 N
        chunk_size: 每个 chunk 的帧数 k
    Returns:
        mask: (N, N) 二值矩阵, 1=可见, 0=遮挡
    """
    mask = torch.zeros(num_frames, num_frames)
    for i in range(num_frames):
        for j in range(num_frames):
            if j // chunk_size <= i // chunk_size:
                mask[i, j] = 1.0
    return mask  # (N, N)

5.2 Asymmetric DMD Training

import torch
import torch.nn.functional as F
 
def asymmetric_dmd_training_step(
    student_G: "CausalDiT",           # 因果 student generator (4-step)
    s_data: "BidirectionalDiT",        # 双向 teacher score function (frozen)
    s_gen: "CausalDiT",               # generator 分布的 score function (online)
    video_batch: torch.Tensor,          # (B, N, C, H, W) clean video
    chunk_size: int,
    noise_schedule: "NoiseSchedule",
    timesteps_student: list,            # e.g., [999, 748, 502, 247]
):
    """
    一步非对称 DMD 蒸馏训练。
 
    核心:用双向 teacher 的 s_data 提供"真实分布"梯度,
    用在线训练的 s_gen 提供"生成分布"梯度,
    二者之差驱动因果 student 向真实分布对齐。
    """
    B, N, C, H, W = video_batch.shape
    L = N // chunk_size  # chunk 数量
 
    # --- Step 1: Student 生成 ---
    # 每个 chunk 独立采样噪声时间步
    t_per_chunk = torch.randint(0, len(timesteps_student), (B, L))  # (B, L)
    noise = torch.randn_like(video_batch)  # (B, N, C, H, W)
 
    # 对每个 chunk 加对应时间步的噪声
    noisy_video = add_noise_per_chunk(video_batch, noise, t_per_chunk,
                                       chunk_size, noise_schedule)
 
    # Student 用 block-wise causal mask 预测 clean video
    x0_pred = student_G(noisy_video, t_per_chunk)  # (B, N, C, H, W)
 
    # --- Step 2: DMD Loss 计算 ---
    # 对预测结果加随机噪声(用于 score function 评估)
    t_dmd = torch.randint(0, 1000, (B,))  # 随机采样单一时间步
    noise_dmd = torch.randn_like(x0_pred)
    x_t = noise_schedule.add_noise(x0_pred, noise_dmd, t_dmd)  # (B, N, C, H, W)
 
    # 双向 teacher score(frozen):真实数据分布的梯度方向
    with torch.no_grad():
        score_data = s_data(x_t, t_dmd)  # 双向 attention, 全帧可见
 
    # Generator score(online):当前生成分布的梯度方向
    score_gen = s_gen(x_t, t_dmd)  # 因果 attention
 
    # DMD gradient: score_data - score_gen 驱动 student 更新
    # 实际实现中通过 score difference 作为伪梯度反传
    dmd_loss = compute_dmd_gradient(x0_pred, score_data, score_gen, t_dmd)
 
    # --- Step 3: 更新 s_gen 的 denoising loss ---
    # 用 student 的输出作为"数据"训练 s_gen
    with torch.no_grad():
        x0_detached = x0_pred.detach()
    noise_new = torch.randn_like(x0_detached)
    t_new = torch.randint(0, 1000, (B,))
    x_t_new = noise_schedule.add_noise(x0_detached, noise_new, t_new)
 
    eps_pred = s_gen(x_t_new, t_new)
    s_gen_loss = F.mse_loss(eps_pred, noise_new)  # 标准 denoising loss
 
    return dmd_loss, s_gen_loss

5.3 KV Caching Inference

import torch
 
def autoregressive_inference_with_kv_cache(
    student_G: "CausalDiT",
    prompt_embedding: torch.Tensor,
    num_chunks: int,
    chunk_size: int,
    timesteps: list = [999, 748, 502, 247],  # 4 步去噪
):
    """
    Algorithm 2: 使用 KV Caching 的流式自回归推理。
 
    核心:每个 chunk 去噪完成后,计算其 KV pairs 并缓存;
    后续 chunk 去噪时复用历史 KV,避免重复计算。
    """
    kv_cache = None  # 初始化为空
    all_chunks = []
 
    for i in range(num_chunks):
        # 从纯噪声开始
        x_i = torch.randn(1, chunk_size, C, H, W)  # 当前 chunk
 
        # 4 步迭代去噪
        for j in range(len(timesteps) - 1, -1, -1):
            t_j = timesteps[j]
            # 使用 KV cache 中之前 chunk 的信息
            x0_pred = student_G.forward_with_cache(x_i, t_j, kv_cache)
 
            if j > 0:
                # DDIM-style 更新到下一个时间步
                t_prev = timesteps[j - 1]
                x_i = ddim_step(x0_pred, x_i, t_j, t_prev)
            else:
                x_i = x0_pred  # 最后一步直接输出
 
        # 去噪完成,更新 KV cache
        # 对 clean chunk 做一次 t=0 的前向传播,获取 KV pairs
        with torch.no_grad():
            new_kv = student_G.compute_kv(x_i, t=0)
        kv_cache = append_to_cache(kv_cache, new_kv)
 
        all_chunks.append(x_i)
        # 此时可以立即输出/显示当前 chunk(流式)
 
    return torch.cat(all_chunks, dim=1)  # (1, N, C, H, W)

7. Code Mapping

GitHub 仓库: tianweiy/CausVid

论文概念代码位置说明
Block-wise Causal DiTcausvid/models/wan/causal_model.py因果注意力 mask 实现,chunk 内双向、chunk 间因果
Bidirectional DiT Wrappercausvid/models/wan/wan_wrapper.py双向 teacher 模型封装
DMD Losscausvid/dmd.py + causvid/loss.pyDistribution Matching Distillation 损失计算,包含 score function 差值梯度
Asymmetric Distillation Trainingcausvid/train_distillation.py主训练循环:因果 student + 双向 teacher 的非对称蒸馏
ODE Trajectory Generationcausvid/models/wan/generate_ode_pairs.py用双向 teacher 的 ODE solver 生成初始化数据
ODE Regression (Student Init)causvid/train_ode.py + causvid/ode_regression.pyStudent 在 ODE 对上的预训练回归
ODE Data Pipelinecausvid/ode_data/create_lmdb_iterative.pyODE 轨迹数据的 LMDB 存储和读取
Bidirectional Trajectory Pipelinecausvid/bidirectional_trajectory_pipeline.py双向模型的完整 ODE 采样流程
Noise Schedule (Flow Matching)causvid/models/wan/flow_match.py + causvid/scheduler.py噪声调度和 flow matching 相关实现
Model Interfacecausvid/models/model_interface.py统一模型接口(支持 SDXL 和 Wan 架构)
Data Loadingcausvid/data.py视频数据加载和预处理
Utilitiescausvid/util.py通用工具函数
AR Inference (Short Video)minimal_inference/autoregressive_inference.py5 秒自回归推理脚本
AR Inference (Long Video)minimal_inference/longvideo_autoregressive_inference.py长视频流式推理(多 chunk rollout)
Bidirectional Inferenceminimal_inference/bidirectional_inference.py双向模型推理(对比基线)
KV Caching Inferencecausvid/models/wan/causal_inference.pyKV cache 管理和因果推理逻辑
Config: Causal DMDconfigs/wan_causal_dmd.yaml因果 DMD 蒸馏的超参数配置
Config: Causal ODEconfigs/wan_causal_ode.yamlODE 预训练的超参数配置
Config: Bidirectional DMDconfigs/wan_bidirectional_dmd.yaml双向 DMD 蒸馏配置(消融实验)

训练 Pipeline 总览

1. 数据准备 (distillation_data/)
   └→ 视频收集 + VAE latent 预计算
2. ODE 数据生成 (causvid/models/wan/generate_ode_pairs.py)
   └→ 双向 teacher 运行 ODE solver 生成轨迹
3. ODE 回归预训练 (causvid/train_ode.py, configs/wan_causal_ode.yaml)
   └→ 因果 student 在 ODE 对上做 MSE 回归, 3000 步
4. DMD 蒸馏 (causvid/train_distillation.py, configs/wan_causal_dmd.yaml)
   └→ 非对称蒸馏: 双向 teacher score + 在线 gen score, 6000 步
5. 推理 (minimal_inference/autoregressive_inference.py)
   └→ 4 步去噪 + KV caching 流式生成