FramePack: Frame Context Packing and Drift Prevention in Next-Frame-Prediction Video Diffusion Models

Authors: Lvmin Zhang, Shengqu Cai, Muyang Li, Gordon Wetzstein, Maneesh Agrawala
Affiliations: Stanford University, MIT

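1. Motivation

Core problem: the forgetting-drifting dilemma in next-frame prediction

The next-frame prediction paradigm in video diffusion models faces two core challenges:

  1. Forgetting: the model struggles to remember earlier content and to maintain temporal dependencies. As more frames are generated, information from early frames is gradually lost.
  2. Drifting: because of error accumulation, visual quality degrades over time. This is also known as exposure bias or observation bias.

The two problems pull in opposite directions:

  • Strengthening memory to counter forgetting → more error propagation → worse drifting
  • Weakening temporal dependency to counter drifting → less historical information → worse forgetting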

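Why existing methods fall short

  • Naive full-context: encodes every history frame, but the quadratic attention complexity of transformers makes this computationally infeasible for large frame counts T (per-frame context length L_f ~ 1560 at 480p, total length L = L_f * (T+S), where S is the number of frames being predicted)
  • DiffusionForcing: adds noise to history frames to reduce drifting, at the cost of weaker memory
  • Causal attention: speeds up inference but still drifts
  • Anchor frame: useful as a planning element, but does not address the underlying context-compression problem

The FramePack insight

Video frames contain substantial temporal redundancy, and frames differ in how much they matter for predicting the next frame: the closer a frame is to the prediction target, the more important it is, so more distant frames can be encoded at higher compression rates. With this importance-based progressive compression, the total context length converges to a fixed upper bound, which makes the computation independent of video length.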


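2. Idea

The core idea of FramePack has two parts:

2.1 Frame Context Packing (anti-forgetting)

Input frames are ranked by importance and encoded with patchify kernels of different compression rates:

  • Most important frames (closest in time): small kernel (e.g. 1x2x2), longest context
  • Moderately important frames: medium kernel (e.g. 2x4x4)
  • Least important frames (farthest in time): large kernel (e.g. 4x8x8), shortest context

The total context length follows a geometric series and, as the number of frames T goes to infinity, converges to the finite value L = (S + lambda/(lambda-1)) * L_f.

2.2 Drift Prevention (anti-drifting)

Three complementary anti-drifting strategies:

  1. Planned Endpoints: generate the first and last frames first, then fill in the middle, turning the one-directional causal chain into bidirectional conditioning
  2. Inverted Sampling Order: for image-to-video, reverse the generation order so that every generation step moves toward the high-quality input frame
  3. History Discretization: quantize history frames into discrete tokens with a K-Means codebook to reduce the train-inference distribution gap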

3. Method

3.1 Packing Frame Context

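Figure 1 walkthrough: the figure shows several FramePack variants.

  • (a) Time-proximity-based permutation: input frames are ordered by temporal distance. The nearest frame F_0 gets the longest context, the farthest frame F_{T-1} is compressed the most, and the packed context is fed to the DiT to predict the next frame.
  • (b) Feature-similarity-based permutation: frames are ordered by latent similarity rather than time, with the frames most similar to the prediction target placed first.
  • (c) GPU memory layout of several packing schedules: Geometric progression (the standard geometric-series compression), Temporal level duplication (frames repeated at certain levels), Level duplication (certain compression levels duplicated), Temporal kernel (consecutive frames compressed with a kernel along the time axis), Progression with important start (the starting frame is given an important position), Symmetric progression (first and last frames treated as equally important).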

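Compression formula

The context length of each frame is determined by an importance function:

    phi(F_i) = L_f / lambda^i

where lambda > 1 is the compression parameter (the paper mainly uses lambda=2). The total context length is:

    L = S * L_f + L_f * sum_{i=0}^{T-1} 1/lambda^i = (S + (1 - 1/lambda^T) / (1 - 1/lambda)) * L_f

As T goes to infinity, L converges to (S + lambda/(lambda-1)) * L_f, so the context length is bounded.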
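
The bound can be checked numerically. The sketch below is only an illustration: the function names are hypothetical, and L_f = 1560 and S = 1 are example values chosen for the demonstration, not settings taken from the paper's experiments. It assumes history frame i contributes L_f / lambda^i tokens.

```python
def packed_context_length(T, S=1, lam=2.0, L_f=1560):
    """Total context length when history frame i contributes L_f / lam**i tokens."""
    history = sum(L_f / lam**i for i in range(T))
    return S * L_f + history


def context_length_bound(S=1, lam=2.0, L_f=1560):
    """Limit of packed_context_length as T goes to infinity."""
    return (S + lam / (lam - 1.0)) * L_f


def naive_context_length(T, S=1, L_f=1560):
    """Full-context baseline: every frame keeps all L_f tokens."""
    return L_f * (T + S)


# The naive length grows linearly with T; the packed length saturates.
for T in (1, 4, 16, 64):
    print(T, naive_context_length(T), round(packed_context_length(T), 1))
print("bound:", context_length_bound())
```

With lambda = 2 the packed history never exceeds two frames' worth of tokens, however long the video gets.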

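Packing schedule implementation

Compression is implemented by changing the DiT's patchify kernel. The product of the dimensions of a 3D kernel (p_f, p_h, p_w) determines the compression rate. For example, a compression rate of 64 can be realized with kernels such as (1,8,8), (4,4,4), or (16,2,2).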
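
To make the "product of the dimensions" rule concrete, the following sketch enumerates candidate kernels for a given compression rate. The helper name is hypothetical, and restricting the search to power-of-two sizes with p_h == p_w is an assumption made for brevity.

```python
from itertools import product


def kernels_with_rate(rate, sizes=(1, 2, 4, 8, 16)):
    """List kernels (p_f, p_h, p_w) with p_h == p_w whose volume equals rate."""
    return [(pf, ph, ph) for pf, ph in product(sizes, sizes) if pf * ph * ph == rate]


print(kernels_with_rate(64))  # [(1, 8, 8), (4, 4, 4), (16, 2, 2)]
```

The three results trade temporal against spatial compression at the same overall rate.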

# Pseudocode: Frame Context Packing
# (patchify is a placeholder for the learned patch-embedding projection)
import torch

def pack_frame_context(frames, target_frame, lambda_=2):
    """
    frames: list of T history frames [F_0, F_1, ..., F_{T-1}]
             F_0 is most recent, F_{T-1} is oldest
    target_frame: X (frame to predict)
    lambda_: compression ratio base (default=2)
    """
    packed_contexts = []

    # Target frame: base patchify kernel (1x2x2), no extra compression
    packed_contexts.append(patchify(target_frame, kernel=(1, 2, 2)))

    # Most recent frame F_0: base kernel (1x2x2)
    packed_contexts.append(patchify(frames[0], kernel=(1, 2, 2)))

    # F_1, F_2: "2x" level (kernel 2x4x4, twice the base kernel per axis)
    packed_contexts.append(patchify(frames[1:3], kernel=(2, 4, 4)))

    # F_3 to F_18: "4x" level (kernel 4x8x8, four times the base kernel per axis)
    packed_contexts.append(patchify(frames[3:19], kernel=(4, 8, 8)))

    # Tail frames: global average pool + largest kernel
    # ...

    # Total context = sum of all packed lengths (bounded)
    return torch.cat(packed_contexts, dim=1)  # along token dim
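
The token budget of this 1 + 2 + 16 schedule can be worked out by hand. The sketch below does the arithmetic under stated assumptions: a latent of 60 x 104 (a 480 x 832 video after 8x spatial VAE downsampling, which reproduces L_f = 1560 with the base kernel), and inputs padded up to whole patches. The helper name is hypothetical.

```python
import math


def num_tokens(frames, height, width, kernel):
    """Tokens produced by a (p_f, p_h, p_w) patchify kernel, padding up to full patches."""
    p_f, p_h, p_w = kernel
    return math.ceil(frames / p_f) * math.ceil(height / p_h) * math.ceil(width / p_w)


# Assumed latent size for a 480 x 832 video (8x spatial downsampling)
H, W = 60, 104

schedule = [
    ("F_0", 1, (1, 2, 2)),          # most recent frame, base kernel
    ("F_1..F_2", 2, (2, 4, 4)),     # two frames at the 2x level
    ("F_3..F_18", 16, (4, 8, 8)),   # sixteen frames at the 4x level
]
counts = {name: num_tokens(n, H, W, k) for name, n, k in schedule}
history_tokens = sum(counts.values())

# Same 19 history frames without packing, all at the base kernel
unpacked_tokens = num_tokens(19, H, W, (1, 2, 2))

print(counts)           # {'F_0': 1560, 'F_1..F_2': 390, 'F_3..F_18': 416}
print(history_tokens)   # 2366
print(unpacked_tokens)  # 29640
```

Under these assumptions the 19 packed history frames cost about as many tokens as one and a half unpacked frames.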

Independent Patchifying Parameters

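Different compression rates use independent neural-network layer parameters; weights are not shared. For very high compression rates (e.g. 16x32x32), the input is first downsampled by 2x2x2 and then passed through the largest kernel (8x16x16). The new projections are initialized by interpolating the weights of the pretrained patchify projection (e.g. the (2,4,4) kernel of HunyuanVideo).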

RoPE Alignment

Different compression kernels produce different context lengths, so the RoPE (Rotary Position Embedding) must be aligned accordingly. This is done by average-pooling (downsampling) the RoPE phases to match each compression kernel.
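
A one-dimensional toy version of this alignment is sketched below. It is an illustration rather than the repository's implementation: the helper name and the phase values are hypothetical, and whether an implementation pools raw angles or cos/sin tables is not specified here.

```python
def avg_pool_1d(values, window):
    """Average consecutive groups of `window` values (length must divide evenly)."""
    if len(values) % window != 0:
        raise ValueError("length must be a multiple of the window")
    return [sum(values[i:i + window]) / window for i in range(0, len(values), window)]


# Hypothetical per-token phases along one axis for the base kernel
base_phases = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]

# A kernel twice as large along this axis covers two base tokens per packed token
phases_2x = avg_pool_1d(base_phases, 2)  # [0.5, 2.5, 4.5, 6.5]
# A kernel four times as large covers four base tokens per packed token
phases_4x = avg_pool_1d(base_phases, 4)  # [1.5, 5.5]
```

Each packed token ends up with the mean phase of the base tokens it covers, so compressed and uncompressed frames share one positional coordinate system.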

Code mapping: process_input_hidden_states()

In diffusers_helper/models/hunyuan_video_packed.py, class HunyuanVideoTransformer3DModelPacked:

# Core frame-packing logic in the actual code (outline only)
def process_input_hidden_states(self, ...):
    # Base latents: 1x compression (standard patchify)
    # clean_latents_2x: 2x4x4 kernel with center downsampling
    # clean_latents_4x: 4x8x8 kernel with center downsampling
    # Each compression rate uses its own projection layer
    ...

3.2 Feature Similarity Based Packing

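Besides time proximity, frames can also be ranked by feature similarity to the prediction target, or by a hybrid of the two criteria.

Hybrid sorting suits world-model datasets (e.g. game scenes), where it can recall previously visited scenes.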
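
A minimal sketch of similarity-based ranking follows. Everything in it is an assumption made for illustration: cosine similarity over flattened latents as the similarity measure, a linear recency penalty with weight alpha as the hybrid score, and the function names. How the reference latent for the prediction target is obtained is outside the sketch.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def rank_frames(history, target, alpha=0.0):
    """Return history indices, most important first.

    history[i] is the flattened latent of frame F_i (i = 0 is the most recent).
    alpha is a hypothetical weight: 0 ranks by similarity only, larger values
    increasingly favor recent frames (a simple hybrid score).
    """
    def score(i):
        return cosine_similarity(history[i], target) - alpha * i

    return sorted(range(len(history)), key=score, reverse=True)


history = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
target = [0.0, 1.0]
print(rank_frames(history, target))             # similarity only
print(rank_frames(history, target, alpha=2.0))  # recency dominates
```

The returned order would then decide which frames get the small kernels and which get the large ones.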

3.3 Drift Prevention

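Figure 2 walkthrough: the figure shows the anti-drifting sampling and training methods.

  • (a) Vanilla: standard sequential generation. Each iteration extends the video to the right, so errors accumulate easily.
  • (b) Anti-drifting: the first and last sections are generated first, and the middle is then filled in step by step, giving bidirectional conditioning.
  • (c) Inverted anti-drifting: for image-to-video, the direction is reversed so that every generation step moves toward the input image (a high-quality anchor).
  • (d) Planned anti-drifting: multiple prompts set endpoints at different positions and the gaps between them are then filled, which supports complex narratives.
  • (e) History discretization: K-Means over a latent dataset builds a codebook, and continuous history frames are quantized into discrete indices. At inference, history frames are replaced by their nearest codebook entries, which reduces the train-test distribution gap.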

3.3.1 Planned Endpoints and Adjusted Sampling Order

Core idea: a strictly causal chain P(X_t|X_{t-1}) is prone to drift, whereas a bidirectional model P(X_t|X_{t1}, X_{t2}) with t1 < t < t2 is more stable.

# Pseudocode: Inverted Anti-drifting Sampling (for image-to-video)
def inverted_antidrift_sampling(input_image, prompt, total_sections):
    """
    input_image serves as the first frame (high-quality anchor).
    Generation order is inverted: sections are generated from the far end
    back toward the input image.
    """
    start_latent = vae_encode(input_image)
    history_latents = []  # accumulated generated frames

    # Padding schedule: [3, 2, 2, ..., 2, 1, 0]
    # 3 = farthest from the input image, 0 = adjacent to it (reverse of vanilla order)
    if total_sections > 4:
        latent_paddings = [3] + [2] * (total_sections - 3) + [1, 0]
    else:
        # Short videos: plain countdown, so the list always has total_sections entries
        latent_paddings = list(reversed(range(total_sections)))

    for padding in latent_paddings:
        is_last = (padding == 0)

        # Split the history into contexts at different compression levels
        clean_latents_post, clean_latents_2x, clean_latents_4x = \
            split_history(history_latents)

        # Always include input_image as a condition
        clean_latents = concat(start_latent, clean_latents_post)

        # Generate a new section
        new_latents = sample_hunyuan(
            clean_latents=clean_latents,         # 1x level
            clean_latents_2x=clean_latents_2x,   # 2x level
            clean_latents_4x=clean_latents_4x,   # 4x level
            latent_padding=padding
        )

        # Prepend the newly generated frames to the history
        history_latents = concat(new_latents, history_latents)

    return history_latents
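
The padding pattern [3, 2, ..., 2, 1, 0] only has exactly total_sections entries when there are at least three sections. The sketch below wraps the schedule in a helper that stays well-defined for short videos. The helper name is hypothetical, and falling back to a plain countdown for four or fewer sections is an assumption of this sketch.

```python
def make_latent_paddings(total_sections):
    """Padding value per generation step, farthest section first."""
    if total_sections < 1:
        raise ValueError("need at least one section")
    if total_sections > 4:
        # Long videos: one far step, repeated middle steps, then the last two
        return [3] + [2] * (total_sections - 3) + [1, 0]
    # Short videos: count down to 0 so the list length matches the section count
    return list(reversed(range(total_sections)))


print(make_latent_paddings(6))  # [3, 2, 2, 2, 1, 0]
print(make_latent_paddings(3))  # [2, 1, 0]
```

In both branches the final entry is 0, which marks the section adjacent to the input image.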

3.3.2 History Discretization

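History frames are quantized into discrete tokens to reduce the train-test distribution gap. Each history latent h is replaced by its nearest codebook entry:

    Q(h) = argmin_{omega in Omega} ||h - omega||

where Omega is the codebook (K cluster centers) learned with K-Means from the latents of the training data.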

# Sketch: History Discretization (PyTorch)
import torch

def discretize_history(history_frames, codebook, K=128):
    """
    history_frames: [B, C, T, H, W] continuous latent history
    codebook: [K, C] learned via K-Means on training latents
    K: codebook size (K=128 recommended)
    """
    B, C, T, H, W = history_frames.shape
    assert codebook.shape == (K, C)

    # Flatten time and spatial dims: one C-dim vector per latent position
    pixels = history_frames.reshape(B, C, -1).permute(0, 2, 1)  # [B, T*H*W, C]

    # Find nearest codebook entry for each latent position
    batched_codebook = codebook.unsqueeze(0).expand(B, -1, -1)  # [B, K, C]
    distances = torch.cdist(pixels, batched_codebook)  # [B, T*H*W, K]
    indices = distances.argmin(dim=-1)  # [B, T*H*W]

    # Replace with codebook entries
    discrete_history = codebook[indices]  # [B, T*H*W, C]
    discrete_history = discrete_history.permute(0, 2, 1).reshape(B, C, T, H, W)

    return discrete_history

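Choosing K:

  • K=1: the history collapses to a single color, which eliminates drift entirely but leaves no coherence between sections
  • K=infinity: equivalent to no discretization, so drift remains
  • K=128: experiments show this is a good balance, with strong drift reduction and minimal training difficulty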
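
The two extremes are easy to reproduce on toy data. The sketch below uses plain Python tuples in place of latents; the helper names and the example values are hypothetical, and the codebooks are written by hand rather than learned with K-Means.

```python
def nearest_entry(vector, codebook):
    """Return the codebook entry closest to vector in squared Euclidean distance."""
    return min(codebook, key=lambda c: sum((v - x) ** 2 for v, x in zip(vector, c)))


def discretize(vectors, codebook):
    """Replace every vector by its nearest codebook entry."""
    return [nearest_entry(v, codebook) for v in vectors]


history = [(0.0, 0.1), (0.9, 1.0), (0.2, 0.0), (1.0, 0.8)]

# K=1: every vector maps to the same entry, so the history carries no detail
k1 = discretize(history, [(0.5, 0.5)])

# K=2: coarse structure survives, small deviations are snapped away
k2 = discretize(history, [(0.0, 0.0), (1.0, 1.0)])

# Codebook containing every input: identical to no discretization
k_all = discretize(history, list(history))

print(k1, k2, k_all, sep="\n")
```

Small deviations, which is where accumulated error lives, are removed at K=2 while the coarse layout is kept; that is the trade-off an intermediate K aims for.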

3.4 Soft Append for Video Stitching

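At inference time, adjacent generated sections overlap, and the overlap is stitched smoothly with linear interpolation:

# Code mapping: diffusers_helper/utils.py
# (sketch that follows the docstring below; the repository version may differ in detail)
import torch

def soft_append_bcthw(current_pixels, history_pixels, overlapped_frames):
    """
    current_pixels: [B, C, T1, H, W] newly generated section (pixel space), earlier in time
    history_pixels: [B, C, T2, H, W] previously accumulated history, later in time
    overlapped_frames: int, number of frames in the overlap region

    The overlap region is blended with linear weights:
    weight = linspace(0, 1, overlapped_frames)
    blended = current * (1-weight) + history * weight
    """
    n = overlapped_frames
    if n <= 0:  # no overlap: plain concatenation along the time axis
        return torch.cat([current_pixels, history_pixels], dim=2)
    weight = torch.linspace(0, 1, n, dtype=current_pixels.dtype,
                            device=current_pixels.device).view(1, 1, -1, 1, 1)
    # Overlap: last n frames of current against the first n frames of history
    blended = current_pixels[:, :, -n:] * (1 - weight) + history_pixels[:, :, :n] * weight
    # Non-overlapping parts are concatenated unchanged
    return torch.cat([current_pixels[:, :, :-n], blended, history_pixels[:, :, n:]], dim=2)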

3.5 Complete inference pipeline

# Pseudocode: Complete FramePack Inference Pipeline
def framepack_inference(input_image, prompt, total_seconds=30,
                         latent_window_size=9, steps=25, cfg=1.0):
    # cfg: guidance scale (the default here is a placeholder)
    # 1. Encode prompt
    llama_vec, clip_pooler = encode_prompt(prompt)

    # 2. Encode input image
    start_latent = vae_encode(input_image)  # [1, 16, 1, H/8, W/8]
    image_embed = clip_vision_encode(input_image)

    # 3. Initialize history
    # Shape: [1, 16, 1+2+16, H/8, W/8] = 19 frames of context
    history_latents = zeros(1, 16, 1+2+16, H//8, W//8)
    history_pixels = None
    total_generated = 0

    # 4. Calculate section count
    total_latent_frames = total_seconds * 30 / 4  # 30fps, VAE temporal compress 4x
    total_sections = ceil(total_latent_frames / latent_window_size)

    # 5. Generate sections (inverted order for anti-drifting)
    if total_sections > 4:
        paddings = [3] + [2]*(total_sections-3) + [1, 0]
    else:
        paddings = list(reversed(range(total_sections)))  # short videos: plain countdown

    for padding in paddings:
        is_last = (padding == 0)

        # 5a. Split history into multi-scale context
        post, ctx_2x, ctx_4x = history_latents[:,:,:1+2+16].split([1,2,16], dim=2)
        clean = cat([start_latent, post], dim=2)  # always include input image

        # 5b. Compute frame indices for RoPE
        indices = compute_indices(padding, latent_window_size, is_last)

        # 5c. Denoise with multi-scale conditioning
        generated = sample_hunyuan(
            clean_latents=clean,           # 1x: start_image + recent 1 frame
            clean_latents_2x=ctx_2x,       # 2x: 2 frames compressed
            clean_latents_4x=ctx_4x,       # 4x: 16 frames compressed
            latent_indices=indices,
            steps=steps, cfg=cfg
        )

        # 5d. Accumulate history
        if is_last:
            # The final section sits next to the input image: put its latent first
            # (this accounts for the "+ 1" in section_frames below)
            generated = cat([start_latent, generated], dim=2)
        total_generated += generated.shape[2]
        history_latents = cat([generated, history_latents], dim=2)

        # 5e. Decode and stitch with soft append
        real_history = history_latents[:,:,:total_generated]
        if history_pixels is None:
            history_pixels = vae_decode(real_history)
        else:
            section_frames = latent_window_size * 2 + (1 if is_last else 0)
            overlap = latent_window_size * 4 - 3
            current_pixels = vae_decode(real_history[:,:,:section_frames])
            history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlap)

    # 6. Save as video
    save_mp4(history_pixels, fps=30)
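
The frame bookkeeping in steps 4 and 5e can be checked in isolation. The sketch below assumes that n latent frames decode to 4n - 3 pixel frames (the first latent frame maps to a single pixel frame), which is consistent with the overlap of latent_window_size * 4 - 3 used above. The helper names are hypothetical.

```python
import math


def latent_to_pixel_frames(latent_frames, temporal_factor=4):
    """Pixel frames decoded from n latent frames, assuming 4n - 3 for a factor of 4."""
    return latent_frames * temporal_factor - (temporal_factor - 1)


def plan_sections(total_seconds, fps=30, latent_window_size=9, temporal_factor=4):
    """Number of sections needed to cover the requested duration."""
    total_latent_frames = total_seconds * fps / temporal_factor
    return math.ceil(total_latent_frames / latent_window_size)


print(plan_sections(30))          # 25 sections for a 30-second video
print(plan_sections(5))           # 5 sections for a 5-second video
print(latent_to_pixel_frames(9))  # 33 pixel frames, the overlap for a window of 9
```

With the defaults above, a 30-second video takes 25 sampling rounds regardless of how much history has accumulated, because the packed context stays at 1 + 2 + 16 frames.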

4. Experimental Setup

Base Model

  • HunyuanVideo (13B parameters) and Wan serve as base models
  • Both natively support next-frame-section prediction
  • They are finetuned with FramePack without changing the original model architecture

Training

  • Hardware: H100 GPU clusters
  • Batch size: FramePack reaches a batch size of ~64 (480p, 13B model) on a single 8xA100-80G node, significantly larger than what full-video generation allows
  • Training method: LoRA training, window size 2-3 (or batch size 32, window size 4-5)
  • Data: follows the LTXVideo dataset collection pipeline, with multiple resolutions and quality levels

Evaluation

  • Test set: 512 real user prompts (text-to-video) + 512 image-prompt pairs (image-to-video)
  • Video length: 30 seconds (long video), 5 seconds (short video)
  • Global Metrics: Clarity (MUSIQ), Aesthetic (LAION), Motion (RAFT-VBench), Dynamic (RAFT-VBench), Semantic (ViCLIP), Anatomy (ViT-VBench), Identity (ArcFace+RetinaFace)
  • Drifting Metrics: start-end contrast Delta_drift^M = |M(V_start) - M(V_end)|, which compares the quality of the first 15% and the last 15% of the video's frames
  • Human Assessment: A/B test, ELO-K32 scoring, at least 100 evaluations per ablation
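
The two scoring rules above can be sketched in a few lines. Both functions are illustrations under stated assumptions: the drift metric approximates M on a clip by the mean of hypothetical per-frame scores, and "ELO-K32" is read as the standard Elo update with K-factor 32. The function names are hypothetical.

```python
def drift_metric(frame_scores, fraction=0.15):
    """|M(V_start) - M(V_end)|, with M approximated by the mean per-frame score."""
    n = max(1, round(len(frame_scores) * fraction))
    start = sum(frame_scores[:n]) / n   # first 15% of frames
    end = sum(frame_scores[-n:]) / n    # last 15% of frames
    return abs(start - end)


def elo_update(rating_a, rating_b, a_wins, k=32):
    """Standard Elo update for one A/B comparison; returns the new (A, B) ratings."""
    expected_a = 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))
    score_a = 1.0 if a_wins else 0.0
    delta = k * (score_a - expected_a)
    return rating_a + delta, rating_b - delta


# A video whose quality score collapses halfway through shows maximal drift
print(drift_metric([1.0] * 10 + [0.0] * 10))  # 1.0
# Two equally rated methods: the winner gains 16 points, the loser drops 16
print(elo_update(1000, 1000, a_wins=True))    # (1016.0, 984.0)
```

A drift value near zero means the start and the end of the video score alike on metric M, which is the behavior the anti-drifting variants aim for.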

Ablative Naming Convention

  • td_f16k4f4k2f1k1_g9: td = delete tail, f16k4 = 16 frames with the k4 kernel, f4k2 = 4 frames with k2, f1k1 = 1 frame with k1, g9 = generate 9 frames
  • +D: history discretization
  • _x_f1k1: anti-drifting with endpoint frame
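
The base naming pattern is regular enough to parse mechanically. The parser below is a hypothetical helper inferred only from the single example td_f16k4f4k2f1k1_g9; it does not handle the +D or _x_f1k1 variants, whose exact placement in a name is not shown here.

```python
import re


def parse_variant(name):
    """Parse names such as 'td_f16k4f4k2f1k1_g9' into their parts."""
    match = re.fullmatch(r"([a-z]+)_((?:f\d+k\d+)+)_g(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognized variant name: {name}")
    tail, levels, generated = match.groups()
    # Each fNkM pair means: N frames encoded with the kM kernel
    packing = [(int(f), int(k)) for f, k in re.findall(r"f(\d+)k(\d+)", levels)]
    return {"tail": tail, "packing": packing, "generated_frames": int(generated)}


print(parse_variant("td_f16k4f4k2f1k1_g9"))
```

For the example name this yields a tail mode of "td", the packing levels (16, 4), (4, 2), (1, 1), and 9 generated frames.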

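5. Experimental Results

5.1 Ablation Study (Table 1)

Main findings:

| Finding | Details |
|---|---|
| Inverted anti-drifting performs best | Best on 4 of 7 metrics, best on all drifting metrics, Rank 1 in human preference (ELO) |
| Dynamic range differs | Inverted anti-drifting has a smaller dynamic range; vanilla + discrete history has the largest |
| Vanilla sampling's top dynamic score is an artifact | The high dynamic score comes from the drifting effect rather than real motion, and its ELO score is lower |
| Discrete history is highly competitive | Vanilla + discrete history gets a high human score and a large dynamic range, making it another strong option |
| Specific configurations matter little | Within one sampling approach, differences between configurations are small and random; the overall architecture contributes more |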

History Discretization Parameter K

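  • K=128 gives strong drift reduction with minimal training difficulty
  • Larger K: the history approaches the undiscretized case (K=infinity), so drift reduction weakens
  • Smaller K: stronger drift suppression, but training is harder and coherence between sections degrades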

5.2 Comparison with Alternative Methods (Table 2)

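| Method | ELO | Drifting | Notes |
|---|---|---|---|
| Repeating image-to-video | 1003 (Rank 5) | High | Simple repetition, clear quality degradation |
| Anchor frames (StreamingT2V) | 1173 (Rank 2) | Medium | Stabilized with anchor frames |
| Causal attention (CausVid) | 1007 (Rank 4) | Medium | Fast, but drifts |
| DiffusionForcing (best variant) | 1174 (Rank 2) | Medium | sigma_train=0.1, sigma_test=0.1 works best |
| History guidance | 1112 (Rank 3) | Medium-high | Strengthens memory but speeds up error accumulation |
| Inverted anti-drifting (ours) | 1220 (Rank 1) | Lowest | Best on all drift metrics |
| Vanilla + discrete history (ours) | 1224 (Rank 1) | | High dynamic + low drift |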

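5.3 Key advantages

  1. Compute efficiency: next-frame prediction generates a smaller tensor at each step, saving substantial GPU memory compared with full-video generation
  2. Batch size: a 13B video model can be trained on a single node with batch sizes comparable to image diffusion
  3. More balanced diffusion scheduler: fewer extreme flow-shift timesteps
  4. Consumer hardware: a laptop GPU with 6GB/8GB of VRAM can run the 13B model and process thousands of frames
  5. Compatibility: existing pretrained video diffusion models (HunyuanVideo, Wan, etc.) can be finetuned directly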