Autoregressive Adversarial Post-Training for Real-Time Interactive Video Generation

Authors: Shanchuan Lin, Ceyuan Yang, Hao He, Jianwen Jiang, Yuxi Ren, Xin Xia, Yang Zhao, Xuefeng Xiao, Lu Jiang Affiliations: ByteDance Seed arXiv: 2506.09350 Venue: NeurIPS 2025 Code: 未开源

1. Motivation (研究动机)

1.1 现有方法的瓶颈

当前大规模视频生成模型(Diffusion Models)虽然质量优秀,但推理开销巨大,无法满足实时交互式视频生成的需求。核心挑战有三个:

  1. 吞吐量不足:Diffusion model 需要多步迭代去噪(e.g. 4-60 步),即使经过蒸馏到 4 步(如 CausVid),在单 GPU 上仅能达到 9.4fps(640x320),远低于实时 24fps 要求。
  2. 延迟过高:Diffusion forcing 方法每步需要处理多帧上下文,即使使用 KV cache,每个 autoregressive step 仍需计算两帧(当前帧 + 上一帧),计算冗余大。
  3. 长视频生成困难:现有 diffusion-forcing 模型训练时通常只支持固定窗口(e.g. 5 秒),超出训练长度后 error accumulation 严重导致画面崩溃;扩展生成需要 restart 和 re-compute,不适合流式应用。

1.2 为什么探索 Adversarial Training

  • Token-based autoregressive 方法(如 VideoPoet)可以自然利用 KV cache,但逐 token 解码太慢,且离散化会损失质量。
  • Adversarial post-training (APT) 已在图像和短视频领域证明可以将 diffusion model 蒸馏为 one-step generator,但此前仅用于非自回归场景。
  • 本文首次将 adversarial training 扩展到 autoregressive 视频生成场景,结合了两者的优势:一步生成的速度 + 自回归的流式能力。

2. Idea (核心思想)

核心思想:将预训练的 bidirectional video diffusion transformer 通过 autoregressive adversarial post-training (AAPT) 转化为一个逐帧(per-latent-frame)的 one-step causal generator。

关键创新点

  1. Causal Architecture + Input Recycling:将 bidirectional DiT 改造为 block causal attention 架构,每步只需生成一个 latent frame(对应 4 个视频帧),上一步的输出通过 channel concatenation 回收为下一步的输入,配合 KV cache 实现比 diffusion forcing 高 2x 的效率。

  2. Student-Forcing Adversarial Training:不同于 teacher-forcing(输入 ground-truth 帧),在对抗训练阶段采用 student-forcing——generator 仅以第一帧为 ground-truth,之后回收自己的生成结果作为输入,完全模拟推理时的行为。这有效减少了 error accumulation。

  3. Long-Video Training Strategy:通过让 generator 生成长视频(e.g. 60 秒),将其切分为短片段(e.g. 10 秒)进行 discriminator 评估,利用 adversarial objective 不需要逐帧 ground-truth 配对的特性,绕过长视频训练数据稀缺的问题。


3. Method (方法)

3.1 整体架构

Figure 1 解读:左侧为 Generator,右侧为 Discriminator。Generator 是一个 block causal transformer:输入包括 text embedding、noise(随机采样)、condition(交互控制信号)和上一步生成的 frame(recycled input)。通过 block causal attention,text tokens 只 attend to 自身,visual tokens attend to 之前和当前帧的 tokens。每步输出一个 latent frame(4 个视频帧)。第一步的 Frame 0 由用户提供(I2V 设定)。生成的帧被回收为下一步输入,配合 KV cache 避免重复计算。Discriminator 使用相同的 block causal architecture,但将 noise channel 替换为 frame input,condition 输入做 shift 对齐,对每帧独立输出 logit,实现 parallel multi-duration discrimination。

Figure 2 解读:对比 one-step diffusion forcing (DF) 与本文方法的计算效率。DF 在使用 KV cache 时,每个 AR step 仍需对两帧进行 compute(当前帧 + 上一帧作为 condition),而本文方法每步只需 compute 一帧,因为上一帧的信息已通过 channel concatenation 回收到输入中,KV cache 保存了历史信息。这使得本文方法比 DF 快约 2 倍。

3.2 Causal Architecture 详细设计

Block Causal Attention

  • 将原始 bidirectional full attention 替换为 block causal attention
  • Text tokens 只 attend to text tokens
  • Visual tokens attend to 当前帧和之前所有帧的 tokens(受 sliding window 限制)
  • Sliding window size N=30 latent frames,always attend to text tokens 和第一帧

Input Recycling

  • 每步输入:noise z ~ N(0,I) + condition c_t + recycled frame x_{t-1}(通过 channel concatenation)
  • 第一步:recycled frame 为用户提供的 Frame 0
  • 之后:recycled frame 为上一步 generator 的输出

Positional Embedding

  • 使用 3D RoPE(Rotary Position Embedding)
  • Spatial 维度动态 stretch 支持多分辨率
  • Temporal 维度固定间隔,支持任意长度生成
# Pseudocode: Autoregressive Inference with KV Cache
def aapt_inference(first_frame, conditions, text, num_steps):
    """
    first_frame: 用户提供的初始帧 (latent)
    conditions: list of per-frame conditions (pose, camera, etc.)
    text: text embedding
    num_steps: 生成的 latent frame 数量
    """
    kv_cache = None
    recycled_frame = first_frame  # Frame 0 由用户提供
    generated_frames = [first_frame]
 
    for t in range(num_steps):
        noise = torch.randn_like(recycled_frame)  # 采样随机噪声
 
        # Channel concatenation: [noise, condition, recycled_frame]
        input_t = channel_concat(noise, conditions[t], recycled_frame)
 
        # Single forward pass with KV cache
        output_frame, kv_cache = generator(
            input_t, text, kv_cache,
            use_block_causal_attention=True,
            sliding_window=30
        )
 
        generated_frames.append(output_frame)
        recycled_frame = output_frame.detach()  # 回收生成结果
 
    # Decode all latent frames through causal 3D VAE
    video = vae_decode(generated_frames)  # 每个 latent frame -> 4 video frames
    return video

3.3 三阶段训练流程

训练分为三个顺序阶段:

Stage 1: Diffusion Adaptation

将预训练 bidirectional DiT 适配为 causal autoregressive 架构:

  • 使用 flow-matching parameterization:
  • Timestep 经过 shifting function:,其中
  • Teacher-forcing:输入 ground-truth frames 作为 recycled input,output target shift by one frame
  • 所有帧使用相同的 timestep(不同于 diffusion forcing 的 progressive noise)
  • Loss: MSE on predicted velocity
# Pseudocode: Diffusion Adaptation (Stage 1)
def diffusion_adaptation_step(model, video_frames, text, conditions):
    """Teacher-forcing diffusion training for causal architecture adaptation"""
    B, T = video_frames.shape[:2]  # batch, num_latent_frames
 
    # 采样统一 timestep(所有帧共享)
    t = torch.rand(B) * 1.0  # t ~ U(0, 1)
    t = shift(t, s=24)       # shifting function
 
    # 构造 noisy input
    epsilon = torch.randn_like(video_frames)
    x_t = (1 - t) * video_frames + t * epsilon  # flow-matching interpolation
 
    # Teacher-forcing: recycled input = ground-truth previous frames
    recycled_inputs = video_frames[:, :-1]  # frames 0 to T-2
    noisy_inputs = x_t[:, 1:]               # noisy frames 1 to T-1
    targets = video_frames[:, 1:]           # target frames 1 to T-1
 
    # Forward pass with block causal attention (parallel, all steps at once)
    pred_velocity = model(noisy_inputs, recycled_inputs, text, conditions)
 
    loss = F.mse_loss(pred_velocity, epsilon - targets)
    return loss

训练细节:AdamW optimizer,lr=1e-5,weight decay=0.01。先在 736x416 训练 20k iter(batch=256),再加入 1280x720 训练 6k iter(batch=128),最后提升最长时长到 15 秒训练 4k iter(batch=32)。

Stage 2: Consistency Distillation

将多步 diffusion model 蒸馏为 one-step generator:

  • 在 32 个 fixed steps 上进行 consistency distillation
  • 继续使用 teacher-forcing
  • 不使用 classifier-free guidance(CFG 在自回归设定下会产生 artifacts)
  • 不使用 EMA moving average
  • 训练 5k iterations
# Pseudocode: Consistency Distillation (Stage 2)
def consistency_distillation_step(student, teacher, video_frames, text, conditions):
    """Distill multi-step diffusion to one-step generation"""
    # 采样相邻的 timestep pair (t_n, t_{n+1}) from 32 fixed steps
    t_n, t_n1 = sample_adjacent_timesteps(num_steps=32, s=24)
 
    epsilon = torch.randn_like(video_frames[:, 1:])
    x_tn1 = (1 - t_n1) * video_frames[:, 1:] + t_n1 * epsilon
 
    # Teacher: one ODE step from t_{n+1} to t_n
    with torch.no_grad():
        x_tn = ode_step(teacher, x_tn1, t_n1, t_n, video_frames[:, :-1], text)
 
    # Student: predict clean frame from both noise levels
    pred_from_tn1 = student.predict_clean(x_tn1, t_n1, video_frames[:, :-1], text)
    pred_from_tn = student.predict_clean(x_tn, t_n, video_frames[:, :-1], text)
 
    loss = F.mse_loss(pred_from_tn1, pred_from_tn.detach())
    return loss

Stage 3: Adversarial Training

核心阶段,使用 student-forcing + adversarial objective:

  • Generator:从 consistency distillation weights 初始化
  • Discriminator:从 diffusion adaptation weights 初始化,相同 causal architecture,noise channel 替换为 frame input,添加 per-frame logit output projection
  • Student-Forcing:generator 仅以第一帧为 GT,之后回收自己的生成结果
  • Loss:R3GAN (relativistic pairing loss) + R1/R2 regularization

其中 generator 用 ,discriminator 用

R1/R2 regularization:

其中

# Pseudocode: Student-Forcing Adversarial Training (Stage 3)
def adversarial_training_step(generator, discriminator, first_frame, real_video, text, conditions):
    """
    Student-forcing: generator recycles its own output, mimicking inference.
    Discriminator evaluates in parallel with per-frame logits.
    """
    num_frames = real_video.shape[1]
 
    # === Generator: Student-Forcing Autoregressive Generation ===
    generated_frames = []
    recycled = first_frame  # Only GT input
    kv_cache_g = None
 
    for t in range(num_frames):
        noise = torch.randn(...)
        input_t = channel_concat(noise, conditions[t], recycled.detach())  # detach recycled frame
        frame_t, kv_cache_g = generator(input_t, text, kv_cache_g)
        generated_frames.append(frame_t)
        recycled = frame_t  # Recycle own output (student-forcing)
 
    fake_video = torch.stack(generated_frames, dim=1)
 
    # === Discriminator: Parallel evaluation, per-frame logits ===
    # Discriminator uses block causal attention, processes all frames at once
    fake_logits = discriminator(fake_video, text, conditions)  # [B, T] per-frame logits
    real_logits = discriminator(real_video, text, conditions)   # [B, T] per-frame logits
 
    # R3GAN relativistic pairing loss
    loss_G = -torch.mean(torch.log(torch.sigmoid(fake_logits - real_logits)))
    loss_D = -torch.mean(torch.log(torch.sigmoid(real_logits - fake_logits)))
 
    # R1 & R2 regularization (approximated)
    loss_R1 = approx_r1(discriminator, real_video, sigma=0.1, lam=1000)
    loss_R2 = approx_r2(discriminator, fake_video, sigma=0.1, lam=1000)
 
    return loss_G, loss_D + loss_R1 + loss_R2

3.4 Long-Video Training

解决长视频训练数据稀缺问题的关键策略:

  1. Generator 生成长视频(e.g. 60 秒)
  2. 切分为短片段(e.g. 10 秒),保留 1 秒 overlap 以鼓励片段连续性
  3. Discriminator 对每个片段独立评估(与真实视频片段对比)
  4. Generator 每次只生成一个片段供 discriminator 评估,通过 detached KV cache 连接片段
  5. 每个片段评估后即反向传播,累积梯度
# Pseudocode: Long-Video Training Strategy
def long_video_training(generator, discriminator, first_frame, text, conditions,
                        total_duration=60, segment_duration=10, overlap=1):
    """Train on long videos by segment-wise discriminator evaluation"""
    total_loss_G, total_loss_D = 0, 0
    kv_cache_g = None
    recycled = first_frame
 
    segments = split_into_segments(total_duration, segment_duration, overlap)
 
    for seg_idx, (start, end) in enumerate(segments):
        # Generator produces one segment autoregressively
        segment_frames = []
        for t in range(start, end):
            noise = torch.randn(...)
            input_t = channel_concat(noise, conditions[t], recycled.detach())
            frame_t, kv_cache_g = generator(input_t, text, kv_cache_g)
            segment_frames.append(frame_t)
            recycled = frame_t
 
        fake_segment = torch.stack(segment_frames, dim=1)
 
        # Sample a real segment of same duration from dataset
        real_segment = sample_real_segment(segment_duration)
 
        # Discriminator evaluates this segment
        loss_G, loss_D = adversarial_loss(discriminator, fake_segment, real_segment, text)
        loss_G.backward()  # Backprop per segment for memory efficiency
        loss_D.backward()
 
        total_loss_G += loss_G.item()
        total_loss_D += loss_D.item()
 
        # Detach KV cache for next segment (cut gradient graph)
        kv_cache_g = detach_kv_cache(kv_cache_g)
 
    optimizer_G.step()
    optimizer_D.step()

训练 schedule

  • 先不使用 long-video extension 训练 500 updates(5-10s 视频,lr=3e-6,batch=256)
  • 加入 long-video extension(overlap 1s,max 19s),训练 500 updates
  • 扩展 5 次(max 55s),降低 batch 到 64,lr 降至 1e-5
  • 总计约 256 H100 GPUs 训练约 7 天

3.5 Teacher-Forcing vs Student-Forcing

Figure 8 (Teacher-Forcing Adversarial Training) 解读:展示 teacher-forcing 模式下的对抗训练。Generator 输入 ground-truth frames I1, I2, I3,独立生成 O2, O3, O4。注意 O3 只与 I2 有关联,与 O2 无关(因为输入是 GT 而非自身生成结果)。Discriminator 需要独立评估每个生成结果及其正确的依赖关系。这种模式下训练与推理存在 distribution gap——训练时输入是干净的 GT frame,推理时输入是带误差的生成 frame——导致 error accumulation 在推理时快速恶化。

3.6 Interactive Applications

本文在两个交互式应用场景进行了验证:

  1. Pose-Conditioned Virtual Human Generation:从训练视频中提取人体姿态,编码为 per-frame condition 输入模型。
  2. Camera-Controlled World Exploration:使用修改后的 Plucker embeddings 编码相机位姿,支持用户实时控制相机探索场景。对 CameraCtrl II 做了多处改进:相对位移(frame-to-frame 而非 relative-to-first)、直接编码 camera ray origin+direction、输入缩放到 1 std、随机初始化新 channel projection。

4. Experimental Setup (实验设置)

4.1 模型配置

配置项设置
BackboneMMDiT (8B parameters, 36 transformer blocks)
Discriminator同 Generator 架构 (8B),总训练参数 16B
VAECausal 3D convolution VAE,temporal 压缩 4x,spatial 压缩 8x
Latent frame1 latent frame = 4 video frames
Attention windowN=30 latent frames (always attend to text + first frame)
Block causal attentionFlash Attention 3, for-loop 实现
Positional encoding3D RoPE (spatial dynamic stretch, temporal fixed interval)
ParallelismFSDP (ZERO 2 for generator, ZERO 3 for discriminator) + Ulysses context parallel

4.2 训练配置

阶段OptimizerLRIterationsBatch Size数据
Diffusion AdaptationAdamW1e-520k+6k+4k256/128/325s clips (736x416 +1280x720 15s)
Consistency DistillationAdamW1e-55k同上同上
Adversarial TrainingRMSProp (D)3e-6 (G), APT setting (D)~1500+256645-10s 55s

4.3 评估基准

  • VBench-I2V:120-frame (短视频) 和 1440-frame (1分钟长视频) 两个设定
  • Pose-Conditioned Human Video:AKD (Average Keypoint Distance), IQA, ASE, FID, FVD
  • Camera-Conditioned World Exploration:FVD, Mov, Trans, Rot, Geo, Apr
  • Baselines:CausVid, Wan2.1, Hunyuan, MAGI-1, SkyReel-V2, 及自身 diffusion baseline

4.4 轻量化 VAE Decoder

为适配实时推理,训练了轻量化 VAE decoder:

  • Residual blocks per resolution: 3 2
  • Channels: [128, 256, 512, 512] [64, 128, 256, 512]
  • 速度提升约 3x,无明显质量下降

5. Experimental Results (实验结果)

5.1 主要定量结果

Table 1: VBench-I2V 定量对比 (736x416)

FramesMethodTemporal QualityFrame QualitySubject Consist.Background Consist.Motion Smooth.Dynamic DegreeAestheticImagingI2V SubjectI2V Background
120CausVid92.0065.00--------
120Wan 2.187.9566.5893.8596.5997.8239.1163.5669.5996.8298.57
120Hunyuan89.8064.1893.0698.5398.8054.8060.5867.7897.7197.97
120Ours (Diffusion)90.4066.0894.5896.7698.8052.5262.4469.7197.8999.14
120Ours (AAPT)89.5166.5896.2296.6699.1942.4462.0971.0698.6099.36
1440SkyReel-V282.1953.6778.4386.3899.2847.1553.6853.6596.5098.07
1440MAGI-180.7960.0182.2389.2798.5425.4552.2667.7596.9098.13
1440Ours (Diffusion)86.6560.4982.3889.4898.2966.2656.4664.5195.0197.72
1440Ours (AAPT)89.7962.1687.1589.7499.1176.5056.7767.5596.1197.52

关键发现

  • 120-frame 生成:AAPT 在 frame quality、conditioning scores 方面全面领先,但 temporal quality 略有下降(adversarial training 的 trade-off)
  • 1440-frame 生成:AAPT 在 quality 指标上全面最优,显著优于 SkyReel-V2 和 MAGI-1
  • CausVid temporal quality 高是因为其训练在 12fps 数据上,dynamic degree 更高

Table 4: 延迟与吞吐量对比

MethodParamsH100ResolutionNFELatencyFPS
CausVid5B1x640x35241.30s9.4
Ours8B1x736x41610.16s24.8
MAGI-124B8x-87.00s3.43
SkyReelV214B8x960x544604.50s0.89
Ours8B8x1280x72010.17s24.2

核心优势:单 H100 上 736x416@24fps,延迟仅 0.16 秒;8xH100 上 1280x720@24fps,延迟 0.17 秒。比 CausVid 快约 8 倍(延迟),FPS 高 2.6 倍。

5.2 定性对比

Figure 3 解读:1440-frame(1 分钟)VBench-I2V 生成的定性对比。(a) SkyReel-V2 在 20-30 秒后出现明显 error accumulation,人物面部和背景严重退化。(b) MAGI-1 同样在中后期画面崩溃。(c) 自身 diffusion baseline(使用 extension 方式)也出现退化。(d) AAPT 但不使用 long-video training,在 10 秒后开始 drift(说明 long-video training 的必要性)。(e) 完整 AAPT 模型在整个 60 秒内保持稳定的画面质量。

5.3 交互式应用结果

Pose-Conditioned Virtual Human (Table 2)

  • 在 6 种方法中,pose accuracy (AKD) 排名第二(仅次于 OmniHuman-1)
  • 视觉质量 (IQA, ASE) 排名第二或第三,与 CyberHost 接近

Camera-Conditioned World Exploration (Table 3)

  • 在 6 个指标中 3 个达到 SOTA(FVD, Trans, Geo)
  • 总体与 CameraCtrl2 接近,但 FVD 大幅领先(61.33 vs 73.11)

5.4 Ablation Studies

Long-Video Training Duration (Table 5)

Training DurationTemporal QualityFrame Quality
10s85.8657.92
20s85.6065.69
60s89.7962.16

训练 60 秒的模型显著优于仅训练 10 秒的模型,验证了 long-video training 的有效性。

Teacher-Forcing vs Student-Forcing:Teacher-forcing adversarial training 的模型在推理时几帧后就开始 drift,无法生成有意义的内容。Student-forcing 是减少 error accumulation 的关键。

5.5 Limitations

  1. Subject/Scene 一致性:长视频中可能出现主体和场景的变化,源于 generator 和 discriminator 都使用 basic sliding window
  2. Defect 持续性:one-step 生成偶尔产生缺陷,一旦出现会因 temporal consistency 持续存在
  3. 训练成本:long-video adversarial training 过程较慢
  4. 5 分钟以上:zero-shot 外推到 5 分钟可以生成内容,但有 artifacts