Autoregressive Adversarial Post-Training for Real-Time Interactive Video Generation
Authors: Shanchuan Lin, Ceyuan Yang, Hao He, Jianwen Jiang, Yuxi Ren, Xin Xia, Yang Zhao, Xuefeng Xiao, Lu Jiang
Affiliations: ByteDance Seed
arXiv: 2506.09350
Venue: NeurIPS 2025
Code: not open-sourced
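1. Motivation
1.1 Bottlenecks of Existing Methods
Current large-scale video generation models (diffusion models) produce high-quality results, but their inference cost is too high for real-time interactive video generation. There are three core challenges:
- Insufficient throughput: a diffusion model needs multi-step iterative denoising (e.g. 4-60 steps). Even after distillation to 4 steps (as in CausVid), it reaches only 9.4 fps at 640x352 on a single GPU, far below the 24 fps required for real time.
- High latency: diffusion forcing methods process multiple frames of context at each step. Even with a KV cache, every autoregressive step still computes two frames (the current frame plus the previous frame), which is redundant.
- Difficulty with long videos: existing diffusion-forcing models are usually trained with a fixed window (e.g. 5 seconds). Beyond the training length, error accumulation becomes severe and the picture collapses. Extending the generation requires restarting and re-computing, which is unsuitable for streaming applications.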
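1.2 Why Explore Adversarial Training
- Token-based autoregressive methods (such as VideoPoet) can use a KV cache naturally, but token-by-token decoding is too slow and discretization costs quality.
- Adversarial post-training (APT) has been shown, for images and short videos, to distill a diffusion model into a one-step generator, but so far only in non-autoregressive settings.
- This paper is the first to extend adversarial training to autoregressive video generation, combining the strengths of both: the speed of one-step generation and the streaming capability of autoregression.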
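2. Idea
Core idea: convert a pretrained bidirectional video diffusion transformer into a per-latent-frame, one-step causal generator through autoregressive adversarial post-training (AAPT).
Key innovations:
- Causal Architecture + Input Recycling: the bidirectional DiT is converted to a block causal attention architecture. Each step generates a single latent frame (corresponding to 4 video frames), and the previous step's output is recycled as the next step's input through channel concatenation. Together with the KV cache, this is 2x more efficient than diffusion forcing.
- Student-Forcing Adversarial Training: unlike teacher forcing (which feeds ground-truth frames), the adversarial stage uses student forcing. The generator receives only the first frame as ground truth and afterwards recycles its own outputs as input, exactly mimicking inference. This effectively reduces error accumulation.
- Long-Video Training Strategy: the generator produces a long video (e.g. 60 seconds), which is split into short segments (e.g. 10 seconds) for the discriminator to evaluate. Because the adversarial objective does not need per-frame ground-truth pairing, this sidesteps the scarcity of long-video training data.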
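3. Method
3.1 Overall Architecture
Figure 1 walkthrough: the generator is on the left and the discriminator on the right. The generator is a block causal transformer whose inputs are the text embedding, noise (randomly sampled), the condition (interactive control signal), and the frame generated at the previous step (recycled input). Under block causal attention, text tokens attend only to themselves, while visual tokens attend to tokens of the previous and current frames. Each step outputs one latent frame (4 video frames). Frame 0 at the first step is provided by the user (I2V setting). Generated frames are recycled as the next step's input, and the KV cache avoids recomputation. The discriminator uses the same block causal architecture, but replaces the noise channel with the frame input, shifts the condition input for alignment, and outputs an independent logit for each frame, which enables parallel multi-duration discrimination.
Figure 2 walkthrough: a comparison of the compute cost of one-step diffusion forcing (DF) and the proposed method. With a KV cache, DF still has to compute two frames per AR step (the current frame plus the previous frame as condition). The proposed method computes only one frame per step, because the previous frame's information is already recycled into the input through channel concatenation and the KV cache holds the history. This makes the proposed method about 2x faster than DF.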
3.2 Causal Architecture in Detail
Block Causal Attention:
- The original bidirectional full attention is replaced with block causal attention
- Text tokens attend only to text tokens
- Visual tokens attend to tokens of the current frame and all previous frames (subject to the sliding window)
- The sliding window size is N=30 latent frames; text tokens and the first frame are always attended to
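The attention rules listed here can be made concrete with a small mask builder. This is a minimal sketch in plain Python, not the paper's implementation: the function name `block_causal_mask`, the token layout (text tokens first, then frames in order), and the convention that the window counts the current frame are all assumptions made for illustration.

```python
def block_causal_mask(num_text, num_frames, tokens_per_frame, window):
    """Return mask[q][k] == True when query token q may attend to key token k.

    Assumed layout: text tokens first, then frames 0..num_frames-1 in order.
    Assumed window convention: the current frame plus the previous window-1 frames.
    """
    total = num_text + num_frames * tokens_per_frame

    def frame_of(idx):
        # Text tokens are labelled -1; visual tokens get their frame index.
        return -1 if idx < num_text else (idx - num_text) // tokens_per_frame

    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        fq = frame_of(q)
        for k in range(total):
            fk = frame_of(k)
            if fq == -1:
                mask[q][k] = fk == -1            # text attends to text only
            elif fk == -1 or fk == 0:
                mask[q][k] = True                # always see text and the first frame
            else:
                mask[q][k] = fq - window < fk <= fq  # causal sliding window
    return mask


# Tiny example: 2 text tokens, 5 frames of 2 tokens each, window of 2 frames.
mask = block_causal_mask(num_text=2, num_frames=5, tokens_per_frame=2, window=2)
```

In this toy layout, a token of frame 4 sees the text, frame 0, frame 3 and frame 4, but not frames 1 and 2, which have slid out of the window.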
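Input Recycling:
- Input at each step: noise z ~ N(0, I), condition c_t, and recycled frame x_{t-1}, combined by channel concatenation
- First step: the recycled frame is the user-provided Frame 0
- Afterwards: the recycled frame is the generator's output from the previous step
Positional Embedding:
- Uses 3D RoPE (Rotary Position Embedding)
- The spatial dimensions are stretched dynamically to support multiple resolutions
- The temporal dimension uses a fixed interval, which supports generation of arbitrary length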
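One way to see why a fixed temporal interval supports arbitrary length: with rotary embeddings, the attention score between two tokens depends only on their position offset, not on their absolute positions. The sketch below shows this for a 1D (temporal-only) RoPE in plain Python. It is a simplification: a real 3D RoPE splits channels across the time, height and width axes, and the helper names and the `base` value are illustrative assumptions.

```python
import math


def rope_rotate(vec, pos, base=10000.0):
    """Apply 1D rotary position embedding to an even-length vector."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = pos * base ** (-i / dim)  # one rotation frequency per channel pair
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]
# Same offset (2 frames) at early and late positions gives the same score.
score_early = dot(rope_rotate(q, 5), rope_rotate(k, 3))
score_late = dot(rope_rotate(q, 105), rope_rotate(k, 103))
```

The two scores agree up to floating-point error, so a frame far beyond the training length is attended to in the same way as an early frame with the same offset.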

    # Pseudocode: autoregressive inference with KV cache
    # channel_concat, generator and vae_decode are placeholders for model components.
    import torch

    @torch.no_grad()
    def aapt_inference(first_frame, conditions, text, num_steps):
        """
        first_frame: initial latent frame provided by the user
        conditions: list of per-frame conditions (pose, camera, etc.)
        text: text embedding
        num_steps: number of latent frames to generate
        """
        kv_cache = None
        recycled_frame = first_frame  # Frame 0 is provided by the user
        generated_frames = [first_frame]
        for t in range(num_steps):
            noise = torch.randn_like(recycled_frame)  # sample fresh noise
            # Channel concatenation: [noise, condition, recycled_frame]
            input_t = channel_concat(noise, conditions[t], recycled_frame)
            # Single forward pass with KV cache
            output_frame, kv_cache = generator(
                input_t, text, kv_cache,
                use_block_causal_attention=True,
                sliding_window=30,
            )
            generated_frames.append(output_frame)
            recycled_frame = output_frame  # recycle the generated frame
        # Decode all latent frames through the causal 3D VAE
        # (each latent frame -> 4 video frames)
        video = vae_decode(torch.stack(generated_frames, dim=1))
        return video

3.3 Three-Stage Training Pipeline
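Training consists of three sequential stages:
Stage 1: Diffusion Adaptation
Adapt the pretrained bidirectional DiT to the causal autoregressive architecture:
- Uses the flow-matching parameterization: $x_t = (1 - t)\,x_0 + t\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, and the model predicts the velocity $v = \epsilon - x_0$
- Timesteps pass through a shifting function, commonly written as $\mathrm{shift}(t, s) = \frac{s\,t}{1 + (s - 1)\,t}$, where $t \sim \mathcal{U}(0, 1)$ and $s = 24$
- Teacher forcing: ground-truth frames are used as the recycled input, and the output target is shifted by one frame
- All frames share the same timestep (unlike the progressive noise of diffusion forcing)
- Loss: MSE on the predicted velocity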

    # Pseudocode: diffusion adaptation (Stage 1)
    # model and shift are placeholders; shift is the timestep shifting function.
    import torch
    import torch.nn.functional as F

    def diffusion_adaptation_step(model, video_frames, text, conditions):
        """Teacher-forcing diffusion training for causal architecture adaptation."""
        B, T = video_frames.shape[:2]  # batch, num_latent_frames
        # Sample one timestep per video (shared by all frames): t ~ U(0, 1), then shifted
        t = shift(torch.rand(B, device=video_frames.device), s=24)
        t = t.view(B, *([1] * (video_frames.dim() - 1)))  # broadcast over frame dims
        # Build the noisy input
        epsilon = torch.randn_like(video_frames)
        x_t = (1 - t) * video_frames + t * epsilon  # flow-matching interpolation
        # Teacher forcing: recycled input = ground-truth previous frames
        recycled_inputs = video_frames[:, :-1]  # frames 0 to T-2
        noisy_inputs = x_t[:, 1:]               # noisy frames 1 to T-1
        targets = video_frames[:, 1:]           # target frames 1 to T-1
        # Forward pass with block causal attention (parallel, all steps at once)
        pred_velocity = model(noisy_inputs, recycled_inputs, t, text, conditions)
        # Velocity target epsilon - x_0, taken on the same frames 1 to T-1
        loss = F.mse_loss(pred_velocity, epsilon[:, 1:] - targets)
        return loss

Training details: AdamW optimizer, lr=1e-5, weight decay=0.01. Training first runs 20k iterations at 736x416 (batch size 256), then adds 1280x720 for 6k iterations (batch size 128), and finally raises the maximum duration to 15 seconds for 4k iterations (batch size 32).
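The Stage 1 pseudocode calls `shift(t, s=24)` without defining it. A minimal sketch of the timestep shift and the flow-matching quantities follows, in plain Python on lists of floats. The closed form of `shift` is the commonly used one and is an assumption here, as are the helper names `interpolate` and `velocity_target`. Only `s = 24` and the interpolation come from the pseudocode in these notes.

```python
def shift(t, s=24.0):
    """Timestep shift (assumed form): maps [0, 1] to [0, 1], biased toward 1 for s > 1."""
    return s * t / (1.0 + (s - 1.0) * t)


def interpolate(x0, eps, t):
    """Flow-matching interpolation x_t = (1 - t) * x0 + t * eps."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, eps)]


def velocity_target(x0, eps):
    """Velocity v = eps - x0 that the model is trained to predict."""
    return [b - a for a, b in zip(x0, eps)]


x0 = [0.2, -1.0, 0.5]    # toy "clean" latent values
eps = [1.0, 0.3, -0.7]   # toy noise values
t = shift(0.5)           # 0.96: a uniform midpoint lands close to pure noise
x_t = interpolate(x0, eps, t)
v = velocity_target(x0, eps)
# Subtracting t * v from x_t recovers the clean sample.
recovered = [xt - t * vi for xt, vi in zip(x_t, v)]
```

With s = 24 most sampled timesteps fall near the noisy end, so the model spends most of its training on high noise levels.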
Stage 2: Consistency Distillation
Distill the multi-step diffusion model into a one-step generator:
- Consistency distillation is performed on 32 fixed steps
- Teacher forcing is still used
- Classifier-free guidance is not used (CFG produces artifacts in the autoregressive setting)
- No EMA (exponential moving average) of the weights is used
- Trained for 5k iterations
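As a rough illustration of "32 fixed steps", the sketch below builds a timestep grid and samples an adjacent pair from it. The notes do not specify how the grid is constructed, so the choice here (a uniform grid on [0, 1] passed through the same shift as Stage 1) is an assumption, and `timestep_grid` and this version of `sample_adjacent_timesteps` are hypothetical helpers.

```python
import random


def shift(t, s=24.0):
    """Assumed timestep shift, same form as in Stage 1."""
    return s * t / (1.0 + (s - 1.0) * t)


def timestep_grid(num_steps=32, s=24.0):
    """num_steps + 1 boundaries from 0 (clean) to 1 (pure noise), after shifting."""
    return [shift(i / num_steps, s) for i in range(num_steps + 1)]


def sample_adjacent_timesteps(num_steps=32, s=24.0, rng=random):
    """Pick a random interval and return (t_n, t_{n+1}) with t_n < t_{n+1}."""
    grid = timestep_grid(num_steps, s)
    n = rng.randrange(num_steps)
    return grid[n], grid[n + 1]


grid = timestep_grid()
t_n, t_n1 = sample_adjacent_timesteps(rng=random.Random(0))
```

The teacher takes one ODE step from `t_n1` down to `t_n`, and the student is trained so that its clean-frame predictions at both noise levels agree.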

    # Pseudocode: consistency distillation (Stage 2)
    # sample_adjacent_timesteps, ode_step and predict_clean are placeholders.
    def consistency_distillation_step(student, teacher, video_frames, text, conditions):
        """Distill multi-step diffusion into one-step generation."""
        # Sample an adjacent timestep pair (t_n, t_{n+1}) from 32 fixed steps
        t_n, t_n1 = sample_adjacent_timesteps(num_steps=32, s=24)
        prev_frames = video_frames[:, :-1]  # teacher-forcing context (ground truth)
        clean = video_frames[:, 1:]
        epsilon = torch.randn_like(clean)
        x_tn1 = (1 - t_n1) * clean + t_n1 * epsilon
        with torch.no_grad():
            # Teacher: one ODE step from t_{n+1} to t_n
            x_tn = ode_step(teacher, x_tn1, t_n1, t_n, prev_frames, text)
            # Target: student prediction at the lower noise level, without gradient
            pred_from_tn = student.predict_clean(x_tn, t_n, prev_frames, text)
        # Student prediction at the higher noise level
        pred_from_tn1 = student.predict_clean(x_tn1, t_n1, prev_frames, text)
        loss = F.mse_loss(pred_from_tn1, pred_from_tn)
        return loss

Stage 3: Adversarial Training
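The core stage, using student forcing with an adversarial objective:
- Generator: initialized from the consistency distillation weights
- Discriminator: initialized from the diffusion adaptation weights; same causal architecture, with the noise channel replaced by the frame input and a per-frame logit output projection added
- Student forcing: the generator takes only the first frame as ground truth and afterwards recycles its own outputs
- Loss: R3GAN (relativistic pairing loss) + R1/R2 regularization

$$\mathcal{L}_G = \mathbb{E}\big[f_G\big(D(\hat{x}) - D(x)\big)\big], \qquad \mathcal{L}_D = \mathbb{E}\big[f_D\big(D(\hat{x}) - D(x)\big)\big]$$

where $x$ is a real video and $\hat{x}$ a generated one. The generator uses $f_G(u) = -\log \operatorname{sigmoid}(u)$ and the discriminator uses $f_D(u) = -\log \operatorname{sigmoid}(-u)$, matching the Stage 3 pseudocode in these notes.
R1/R2 regularization (approximated form, as in APT):

$$\mathcal{L}_{aR1} = \lambda \,\big\| D(x) - D(\mathcal{N}(x, \sigma I)) \big\|_2^2, \qquad \mathcal{L}_{aR2} = \lambda \,\big\| D(\hat{x}) - D(\mathcal{N}(\hat{x}, \sigma I)) \big\|_2^2$$

where $\sigma = 0.1$ and $\lambda = 1000$.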
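A small numeric sketch may help make the loss concrete. It evaluates the relativistic pairing loss on per-frame logits and an approximated R1 penalty on a toy scalar discriminator, in plain Python. The helper names, the toy discriminator interface, and the reading of sigma as a standard deviation are assumptions. The values sigma=0.1 and lam=1000 are the ones used in the pseudocode in these notes.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def relativistic_losses(real_logits, fake_logits):
    """Relativistic pairing: each fake logit is compared with its paired real logit."""
    diffs = [f - r for r, f in zip(real_logits, fake_logits)]
    loss_g = sum(softplus(-d) for d in diffs) / len(diffs)  # -log sigmoid(fake - real)
    loss_d = sum(softplus(d) for d in diffs) / len(diffs)   # -log sigmoid(real - fake)
    return loss_g, loss_d


def approx_r1(disc, x, sigma=0.1, lam=1000.0, rng=random):
    """Approximated R1: penalize logit change under small Gaussian input noise."""
    noisy = [v + rng.gauss(0.0, sigma) for v in x]
    return lam * (disc(x) - disc(noisy)) ** 2


# When the discriminator cannot tell real from fake, both losses equal log(2).
balanced_g, balanced_d = relativistic_losses([0.0, 0.0], [0.0, 0.0])
# When real logits are higher, the discriminator loss is small and the generator loss large.
winning_g, winning_d = relativistic_losses([2.0, 1.0], [-1.0, 0.0])
# A discriminator that ignores its input incurs no penalty.
flat_penalty = approx_r1(lambda x: 1.0, [0.0, 1.0])
```

The approximated penalty replaces the gradient norm of the exact R1/R2 terms with a finite difference, which avoids a second-order backward pass through the discriminator.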

    # Pseudocode: student-forcing adversarial training (Stage 3)
    # channel_concat, approx_r1 and approx_r2 are placeholders.
    def adversarial_training_step(generator, discriminator, first_frame, real_video, text, conditions):
        """
        Student forcing: the generator recycles its own output, mimicking inference.
        The discriminator evaluates all frames in parallel with per-frame logits.
        """
        num_frames = real_video.shape[1]
        # === Generator: student-forcing autoregressive generation ===
        generated_frames = []
        recycled = first_frame  # the only ground-truth input
        kv_cache_g = None
        for t in range(num_frames):
            noise = torch.randn_like(first_frame)
            # Detach the recycled frame so gradients do not flow through the recycled input
            input_t = channel_concat(noise, conditions[t], recycled.detach())
            frame_t, kv_cache_g = generator(input_t, text, kv_cache_g)
            generated_frames.append(frame_t)
            recycled = frame_t  # recycle own output (student forcing)
        fake_video = torch.stack(generated_frames, dim=1)
        # === Discriminator: parallel evaluation with block causal attention ===
        fake_logits = discriminator(fake_video, text, conditions)  # [B, T] per-frame logits
        real_logits = discriminator(real_video, text, conditions)  # [B, T] per-frame logits
        # R3GAN relativistic pairing loss (logsigmoid is the stable form of log(sigmoid(x)))
        loss_G = -F.logsigmoid(fake_logits - real_logits).mean()
        loss_D = -F.logsigmoid(real_logits - fake_logits).mean()
        # Approximated R1 and R2 regularization
        loss_R1 = approx_r1(discriminator, real_video, sigma=0.1, lam=1000)
        loss_R2 = approx_r2(discriminator, fake_video.detach(), sigma=0.1, lam=1000)
        # Simplified: a real loop alternates G and D updates and detaches fake_video for D
        return loss_G, loss_D + loss_R1 + loss_R2

3.4 Long-Video Training
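The key strategy for working around the scarcity of long-video training data:
- The generator produces a long video (e.g. 60 seconds)
- The video is split into short segments (e.g. 10 seconds), keeping a 1-second overlap to encourage continuity across segments
- The discriminator evaluates each segment independently (against real video segments)
- The generator produces only one segment at a time for the discriminator to evaluate, and segments are connected through a detached KV cache
- Backpropagation runs right after each segment is evaluated, and the gradients are accumulated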
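The segment boundaries seen by the discriminator can be sketched as follows. `split_into_segments` is a hypothetical helper, written here in seconds. Clipping the last segment to the total duration is an assumption rather than something the notes specify.

```python
def split_into_segments(total_duration, segment_duration, overlap):
    """Return (start, end) pairs in seconds; consecutive segments share `overlap` seconds."""
    stride = segment_duration - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than segment_duration")
    segments, start = [], 0
    while start + overlap < total_duration:
        end = min(start + segment_duration, total_duration)  # clip the final segment
        segments.append((start, end))
        if end == total_duration:
            break
        start += stride
    return segments


one_extension = split_into_segments(19, 10, 1)
five_extensions = split_into_segments(55, 10, 1)
full_minute = split_into_segments(60, 10, 1)
```

With 10-second segments and a 1-second overlap the stride is 9 seconds, so one extension reaches 19 seconds and five extensions reach 55 seconds, which matches the maximum durations quoted in the training schedule in these notes.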

    # Pseudocode: long-video training strategy
    # split_into_segments, sample_real_segment, adversarial_loss, detach_kv_cache,
    # optimizer_G and optimizer_D are placeholders defined elsewhere.
    def long_video_training(generator, discriminator, first_frame, text, conditions,
                            total_duration=60, segment_duration=10, overlap=1):
        """Train on long videos with segment-wise discriminator evaluation."""
        total_loss_G, total_loss_D = 0.0, 0.0
        kv_cache_g = None
        recycled = first_frame
        segments = split_into_segments(total_duration, segment_duration, overlap)
        for seg_idx, (start, end) in enumerate(segments):
            # Generator produces one segment autoregressively
            segment_frames = []
            for t in range(start, end):
                noise = torch.randn_like(first_frame)
                input_t = channel_concat(noise, conditions[t], recycled.detach())
                frame_t, kv_cache_g = generator(input_t, text, kv_cache_g)
                segment_frames.append(frame_t)
                recycled = frame_t
            fake_segment = torch.stack(segment_frames, dim=1)
            # Sample a real segment of the same duration from the dataset
            real_segment = sample_real_segment(segment_duration)
            # Discriminator evaluates this segment
            loss_G, loss_D = adversarial_loss(discriminator, fake_segment, real_segment, text)
            # Backprop per segment for memory efficiency; gradients accumulate across segments.
            # Simplified: a real loop alternates G and D updates on separate graphs.
            loss_G.backward()
            loss_D.backward()
            total_loss_G += loss_G.item()
            total_loss_D += loss_D.item()
            # Detach the KV cache for the next segment (cuts the gradient graph)
            kv_cache_g = detach_kv_cache(kv_cache_g)
        optimizer_G.step()
        optimizer_D.step()

Training schedule:
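- First train for 500 updates without the long-video extension (5-10s videos, lr=3e-6, batch size 256)
- Add the long-video extension (1s overlap, max 19s) and train for 500 updates
- Extend 5 times (max 55s), with the batch size reduced to 64 and the learning rate set to 1e-5
- In total, about 7 days of training on roughly 256 H100 GPUs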
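3.5 Teacher-Forcing vs Student-Forcing
Figure 8 (Teacher-Forcing Adversarial Training) walkthrough: the figure shows adversarial training in teacher-forcing mode. The generator takes the ground-truth frames I1, I2, I3 as input and independently produces O2, O3, O4. Note that O3 is tied only to I2 and not to O2, because the input is ground truth rather than the model's own output. The discriminator has to evaluate each generated result together with its correct dependencies. In this mode there is a distribution gap between training and inference: the training input is a clean ground-truth frame, whereas the inference input is a generated frame that carries errors. As a result, error accumulation worsens quickly at inference time.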
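The distribution gap can be illustrated with a toy scalar rollout. This is not the paper's model: the "generator" below is a hypothetical one-line function with a small constant bias, and the "frames" are plain numbers. It only shows why an error that looks harmless under teacher forcing compounds once the model consumes its own outputs.

```python
def rollout(step_fn, first, num_steps, ground_truth=None):
    """Teacher forcing if ground_truth is given, otherwise student forcing."""
    outputs, prev = [], first
    for t in range(num_steps):
        out = step_fn(prev)
        outputs.append(out)
        # Teacher forcing feeds the true frame; student forcing feeds the model's own output.
        prev = ground_truth[t] if ground_truth is not None else out
    return outputs


def biased_step(prev, bias=0.1):
    """Toy generator: the true dynamics are prev + 1, the model adds a small bias."""
    return prev + 1.0 + bias


truth = [float(t + 1) for t in range(10)]  # true values of frames 1..10, frame 0 is 0.0
teacher = rollout(biased_step, 0.0, 10, ground_truth=truth)
student = rollout(biased_step, 0.0, 10)
teacher_err = [abs(o - g) for o, g in zip(teacher, truth)]
student_err = [abs(o - g) for o, g in zip(student, truth)]
```

Under teacher forcing every step shows the same small error (about 0.1), while the student-forced rollout drifts to an error of about 1.0 after ten steps. Training with student forcing exposes the loss to this accumulated drift instead of hiding it.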
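3.6 Interactive Applications
The method is validated on two interactive applications:
- Pose-Conditioned Virtual Human Generation: human poses are extracted from the training videos and encoded as a per-frame condition for the model.
- Camera-Controlled World Exploration: camera poses are encoded with modified Plucker embeddings, letting the user steer the camera in real time to explore a scene. Several changes are made relative to CameraCtrl II: relative displacement (frame-to-frame rather than relative-to-first), direct encoding of camera ray origin and direction, input scaling to 1 std, and random initialization of the new channel projection.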
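The difference between relative-to-first and frame-to-frame encoding can be seen on camera translations alone. The sketch below is a simplified illustration in plain Python: it ignores rotation and the ray origin/direction encoding, works in world coordinates, and normalizes per sequence. All helper names and these simplifications are assumptions, not the paper's procedure.

```python
import statistics


def relative_to_first(positions):
    """Offsets of every camera position from the first one (grows with video length)."""
    x0 = positions[0]
    return [tuple(a - b for a, b in zip(p, x0)) for p in positions]


def frame_to_frame(positions):
    """Offsets between consecutive camera positions (bounded by per-step motion)."""
    deltas = [(0.0, 0.0, 0.0)]
    for prev, cur in zip(positions, positions[1:]):
        deltas.append(tuple(c - p for p, c in zip(prev, cur)))
    return deltas


def scale_to_unit_std(rows):
    """Rescale all values so their standard deviation is 1 (per-sequence, for illustration)."""
    std = statistics.pstdev([v for row in rows for v in row])
    return [tuple(v / std for v in row) for row in rows] if std > 0 else rows


# A camera moving one unit along x per frame.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
abs_offsets = relative_to_first(positions)
step_offsets = frame_to_frame(positions)
scaled = scale_to_unit_std(step_offsets)
```

The relative-to-first offsets keep growing with the length of the video, whereas the frame-to-frame offsets stay at one unit per step, which keeps the condition in a stable range for generation of arbitrary length.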
4. Experimental Setup
4.1 Model Configuration
| Item | Setting |
|---|---|
| Backbone | MMDiT (8B parameters, 36 transformer blocks) |
| Discriminator | Same architecture as the generator (8B); 16B trainable parameters in total |
| VAE | Causal 3D convolutional VAE, 4x temporal compression, 8x spatial compression |
| Latent frame | 1 latent frame = 4 video frames |
| Attention window | N=30 latent frames (always attend to text + first frame) |
| Block causal attention | Flash Attention 3, implemented with a for-loop |
| Positional encoding | 3D RoPE (spatial dynamic stretch, temporal fixed interval) |
| Parallelism | FSDP (ZeRO-2 for generator, ZeRO-3 for discriminator) + Ulysses context parallel |
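The configuration implies some simple arithmetic that is worth spelling out. The sketch below derives latent sizes and the temporal reach of the attention window from the compression factors in the table and the 24 fps target. The function names are illustrative, and any special handling of the first frame by the causal VAE is ignored.

```python
def latent_layout(width, height, fps=24, temporal_stride=4, spatial_stride=8, window=30):
    """Derive latent-space sizes from the VAE compression factors and window size."""
    return {
        "latent_hw": (height // spatial_stride, width // spatial_stride),
        "latent_frames_per_second": fps / temporal_stride,
        "window_video_frames": window * temporal_stride,
        "window_seconds": window * temporal_stride / fps,
    }


def latent_steps(num_video_frames, temporal_stride=4):
    """Number of autoregressive steps needed for a given number of video frames."""
    return num_video_frames // temporal_stride


low_res = latent_layout(736, 416)
high_res = latent_layout(1280, 720)
short_steps = latent_steps(120)
long_steps = latent_steps(1440)
```

At 24 fps the generator must complete 6 autoregressive steps per second, the 30-frame window covers 5 seconds of video, and a 1440-frame (one-minute) video takes 360 steps.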
4.2 Training Configuration
| Stage | Optimizer | LR | Iterations | Batch Size | Data |
|---|---|---|---|---|---|
| Diffusion Adaptation | AdamW | 1e-5 | 20k+6k+4k | 256/128/32 | 5s clips (736x416 → +1280x720 → 15s) |
| Consistency Distillation | AdamW | 1e-5 | 5k | same as above | same as above |
| Adversarial Training | RMSProp (D) | 3e-6 (G), APT setting (D) | ~1500+ | 256→64 | 5-10s → 55s |
4.3 Evaluation Benchmarks
- VBench-I2V: two settings, 120 frames (short video) and 1440 frames (one-minute long video)
- Pose-Conditioned Human Video: AKD (Average Keypoint Distance), IQA, ASE, FID, FVD
- Camera-Conditioned World Exploration: FVD, Mov, Trans, Rot, Geo, Apr
- Baselines: CausVid, Wan2.1, Hunyuan, MAGI-1, SkyReel-V2, and the authors' own diffusion baseline
4.4 Lightweight VAE Decoder
A lightweight VAE decoder is trained to fit real-time inference:
- Residual blocks per resolution: 3 → 2
- Channels: [128, 256, 512, 512] → [64, 128, 256, 512]
- About 3x faster, with no noticeable quality loss
5. Experimental Results
5.1 Main Quantitative Results
Table 1: Quantitative comparison on VBench-I2V (736x416)
| Frames | Method | Temporal Quality | Frame Quality | Subject Consist. | Background Consist. | Motion Smooth. | Dynamic Degree | Aesthetic | Imaging | I2V Subject | I2V Background |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 120 | CausVid | 92.00 | 65.00 | - | - | - | - | - | - | - | - |
| 120 | Wan 2.1 | 87.95 | 66.58 | 93.85 | 96.59 | 97.82 | 39.11 | 63.56 | 69.59 | 96.82 | 98.57 |
| 120 | Hunyuan | 89.80 | 64.18 | 93.06 | 98.53 | 98.80 | 54.80 | 60.58 | 67.78 | 97.71 | 97.97 |
| 120 | Ours (Diffusion) | 90.40 | 66.08 | 94.58 | 96.76 | 98.80 | 52.52 | 62.44 | 69.71 | 97.89 | 99.14 |
| 120 | Ours (AAPT) | 89.51 | 66.58 | 96.22 | 96.66 | 99.19 | 42.44 | 62.09 | 71.06 | 98.60 | 99.36 |
| 1440 | SkyReel-V2 | 82.19 | 53.67 | 78.43 | 86.38 | 99.28 | 47.15 | 53.68 | 53.65 | 96.50 | 98.07 |
| 1440 | MAGI-1 | 80.79 | 60.01 | 82.23 | 89.27 | 98.54 | 25.45 | 52.26 | 67.75 | 96.90 | 98.13 |
| 1440 | Ours (Diffusion) | 86.65 | 60.49 | 82.38 | 89.48 | 98.29 | 66.26 | 56.46 | 64.51 | 95.01 | 97.72 |
| 1440 | Ours (AAPT) | 89.79 | 62.16 | 87.15 | 89.74 | 99.11 | 76.50 | 56.77 | 67.55 | 96.11 | 97.52 |
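Key findings:
- 120-frame generation: AAPT has the best I2V conditioning scores (98.60 subject, 99.36 background) and ties Wan 2.1 for the best frame quality (66.58), but its temporal quality (89.51) is slightly below the diffusion baseline (90.40), a trade-off of adversarial training
- 1440-frame generation: AAPT has the best temporal quality (89.79) and frame quality (62.16), clearly ahead of SkyReel-V2 and MAGI-1
- CausVid's high temporal quality comes from training on 12 fps data, which gives it a higher dynamic degree
Table 4: Latency and throughput comparison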
| Method | Params | H100 | Resolution | NFE | Latency | FPS |
|---|---|---|---|---|---|---|
| CausVid | 5B | 1x | 640x352 | 4 | 1.30s | 9.4 |
| Ours | 8B | 1x | 736x416 | 1 | 0.16s | 24.8 |
| MAGI-1 | 24B | 8x | - | 8 | 7.00s | 3.43 |
| SkyReel-V2 | 14B | 8x | 960x544 | 60 | 4.50s | 0.89 |
| Ours | 8B | 8x | 1280x720 | 1 | 0.17s | 24.2 |
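Key advantages: 736x416 at 24 fps on a single H100 with a latency of only 0.16 seconds, and 1280x720 at 24 fps on 8x H100 with a latency of 0.17 seconds. Compared with CausVid, latency is about 8x lower (0.16s vs 1.30s) and FPS is about 2.6x higher (24.8 vs 9.4).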
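The latency and throughput figures can be cross-checked with a few lines of arithmetic, assuming each autoregressive step emits one latent frame of 4 video frames and that steps run back to back. The function name is illustrative.

```python
def fps_from_latency(latency_s, frames_per_step=4):
    """Video frames per second if each step takes latency_s and yields frames_per_step frames."""
    return frames_per_step / latency_s


single_gpu_fps = fps_from_latency(0.16)  # about 25.0, reported 24.8
multi_gpu_fps = fps_from_latency(0.17)   # about 23.5, reported 24.2
latency_speedup = 1.30 / 0.16            # CausVid latency over AAPT latency
fps_ratio = 24.8 / 9.4                   # AAPT FPS over CausVid FPS
```

The derived values land close to the reported FPS; the small gaps may come from rounding of the reported latencies. The ratios come out to roughly 8.1x for latency and 2.6x for FPS.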
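5.2 Qualitative Comparison
Figure 3 walkthrough: a qualitative comparison of 1440-frame (one-minute) VBench-I2V generation. (a) SkyReel-V2 shows clear error accumulation after 20-30 seconds, with severe degradation of the faces and the background. (b) MAGI-1 likewise collapses in the middle and late parts of the video. (c) The authors' own diffusion baseline (using extension) also degrades. (d) AAPT without long-video training starts to drift after 10 seconds, which shows that long-video training is necessary. (e) The full AAPT model keeps a stable picture quality over the whole 60 seconds.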
5.3 Interactive Application Results
Pose-Conditioned Virtual Human (Table 2):
- Among 6 methods, pose accuracy (AKD) ranks second (behind only OmniHuman-1)
- Visual quality (IQA, ASE) ranks second or third, close to CyberHost
Camera-Conditioned World Exploration (Table 3):
- Reaches SOTA on 3 of the 6 metrics (FVD, Trans, Geo)
- Overall close to CameraCtrl II, but with a large lead on FVD (61.33 vs 73.11)
5.4 Ablation Studies
Long-Video Training Duration (Table 5):
| Training Duration | Temporal Quality | Frame Quality |
|---|---|---|
| 10s | 85.86 | 57.92 |
| 20s | 85.60 | 65.69 |
| 60s | 89.79 | 62.16 |
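The model trained on 60-second videos clearly outperforms the one trained only on 10-second videos (temporal quality 89.79 vs 85.86, frame quality 62.16 vs 57.92), which supports the effectiveness of long-video training. The 20-second setting has the highest frame quality (65.69) but a lower temporal quality (85.60).
Teacher-Forcing vs Student-Forcing: a model trained with teacher-forcing adversarial training starts to drift after a few frames at inference time and cannot generate meaningful content. Student forcing is the key to reducing error accumulation.
5.5 Limitations
- Subject/scene consistency: the subject and the scene may change over a long video, because both the generator and the discriminator use a basic sliding window
- Persistent defects: one-step generation occasionally produces a defect, and once it appears it persists because of temporal consistency
- Training cost: long-video adversarial training is slow
- Beyond 5 minutes: zero-shot extrapolation to 5 minutes still generates content, but with artifacts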