MemFlow: Flowing Adaptive Memory for Consistent and Efficient Long Video Narratives

Authors: Sihui Ji, Xi Chen, Shuai Yang, Xin Tao, Pengfei Wan, Hengshuang Zhao Affiliations: HKU, Kling Team (Kuaishou Technology), HKUST(GZ) arXiv: 2512.14699 GitHub: KlingTeam/MemFlow

1. Motivation (研究动机)

流式视频生成(streaming video generation)的核心挑战是在长序列中保持叙事一致性(narrative coherence),这对记忆机制的设计提出了很高要求:

  • 固定压缩策略失效:现有方法(如 LongLive 的 Frame Sink)只保留第一帧 chunk 的 KV 作为全局锚点,或通过固定压缩方案存储历史帧。然而不同的待生成 chunk 需要引用不同的历史内容——场景切换后可能需要召回很久之前的人物外观,而固定策略无法做到这种动态适配
  • 冗余主体问题:当 prompt 切换引入新角色时,LongLive 等方法因为 Frame Sink 只锚定第一帧,无法正确关联新 prompt 与历史中的相关视觉线索,导致反复引入冗余人物、背景漂移
  • 效率约束:直接将所有历史帧放入注意力窗口会使计算复杂度随序列长度线性增长,memory bank 的容量必须受到严格限制
  • 记忆与效率的矛盾:引入 memory bank 提升一致性,必然增加计算开销。如何在几乎不损失速度的前提下获得长程记忆?

核心矛盾:固定记忆策略无法为不同 prompt 动态提供语义相关的历史上下文,而引入动态记忆又会显著增加计算开销


2. Idea (核心思想)

MemFlow 的核心 insight:通过文本引导的语义检索动态更新 memory bank + 稀疏激活仅唤醒最相关的记忆 token,可以用极小的计算代价(7.9% 速度损失)实现自适应长程记忆

两个关键设计:

  1. Narrative Adaptive Memory (NAM):在生成每个新 chunk 之前,用当前 prompt 的文本 token 作为 query,计算与 memory bank 中所有历史帧的 cross-attention 相关性分数,检索语义最相关的 top-k 帧,同时将上一个 chunk 的首帧 KV 作为原型(prototype)压缩加入 bank,实现动态更新
  2. Sparse Memory Activation (SMA):生成过程中,对 memory bank 中的 KV 按 chunk 粒度计算与当前 query 的相关性,只激活 top-k 个最相关的 chunk 参与注意力计算,大幅减少计算量

最终:基于 Wan2.1-T2V-1.3B,单卡 H100 上 18.7 FPS,支持 60 秒多 prompt 交互式长视频生成,比 memory-free baseline(LongLive)仅慢 7.9%。


3. Method (方法)

3.1 整体架构

Figure 2 解读:MemFlow 的整体框架由三部分组成——左侧是 Narrative Adaptive Memory (NAM) 模块:给定当前 prompt 和 KV cache memory bank,先用 prompt 的文本 token 查询 memory bank 中的视觉 token(Semantic Retrieval),选出语义最相关的历史帧;再将上一个 chunk 的首帧 KV 作为 Prototypical KV 加入 bank,淘汰不相关帧(Redundant Removal),完成 memory 更新。中间是 AR-Diffusion 框架:更新后的 memory 与 local window 的 KV cache 拼接送入自回归扩散模型生成当前 chunk。下方是 Sparse Memory Activation (SMA):将 memory bank 划分为 frame-level chunk,计算每个 chunk 与当前 query 的相关性分数,只选择 top-k 个最相关的 chunk 参与注意力。右侧展示了自回归生成流程——每生成一个 chunk,memory bank 就动态更新一次。

MemFlow 构建在 hybrid AR-diffusion 框架之上(继承 LongLive 的 Self-Forcing + DMD 蒸馏流程):

  • 每次生成 帧,以最近 帧为局部上下文(local window)
  • 自回归过程中自然产生 KV cache,直接作为 memory bank 的基础数据结构
  • 标准注意力覆盖 帧,引入 memory bank 后扩展到 帧( 为 bank 中的帧数)

训练机制:

  • 采用 Self-Forcing + DMD 蒸馏(Stage 1),将双向模型转为 causal AR few-step 模型
  • Streaming Long Tuning(Stage 2):在 60 秒长序列上训练,每次迭代生成 5 秒 clip,关键是将 NAM 和 SMA 集成到训练中,让模型学会如何从自身生成的历史中检索和利用记忆

3.2 Narrative Adaptive Memory (NAM)

NAM 的目标是根据当前 prompt 动态选择最相关的历史帧,由两个协同机制组成:

(1)Semantic Retrieval(语义检索)

在每个 transformer 层 ,用当前 prompt 的文本 query 与 memory bank 中第 帧的 KV cache 计算语义相关性分数:

其中 是 mean pooling,将注意力权重压缩为标量分数 。然后选择 top- 个分数最高的帧保留。

设计直觉:cross-attention 分数天然衡量文本与视觉之间的语义对齐程度,分数高说明这帧与当前 prompt 语义相关。例如,prompt 提到”小孩在海滩上跑”,memory bank 中包含海滩场景的帧会得到高分。

(2)Redundant Removal(冗余消除 / 原型压缩)

短视频 chunk 内相邻帧高度冗余,无需逐帧存储。MemFlow 用极简策略:只取上一个 chunk 的首帧 KV 作为整个 chunk 的原型(Prototypical KV),将其与检索到的历史帧拼接形成更新后的 memory bank。

相比 context merging 等计算密集型压缩方案,首帧原型策略零额外计算且实测效果等价。

3.3 Sparse Memory Activation (SMA)

Figure 4 解读(上半部分):对比四种记忆机制在 60 秒多 prompt 场景下的表现。“w/o Memory”(无记忆)在后续 prompt 中主体漂移严重;“Frame Sink”(LongLive 的首帧锚定)前期一致但后期主体丢失;“NAM”(仅语义检索)能跨 prompt 保持主体但不够高效;“NAM+SMA”(完整 MemFlow)在保持一致性的同时效率最优。

Figure 5 解读(下半部分):左图分析 NAM 中 memory bank 容量 的影响——(约为 local window 的一半)表现最稳定, 反而低于 baseline, 出现剧烈波动。过大的 bank 导致全局上下文比例失衡,干扰短期叙事流。

SMA 的目标是在不损失记忆质量的前提下减少注意力计算量

  1. 将 memory bank 的 按帧切分为 个 chunk(每个 chunk 1560 tokens,对应一帧的 latent)
  2. 对当前生成 chunk 的 query 和每个 memory chunk 的 key 分别做 mean pooling,得到紧凑描述子
  3. 计算相关性分数:
  4. 选择 top- 个分数最高的 chunk:
  5. 注意力只在选中的 chunk 上计算:
def dynamic_topk_routing_attention(query, key, value, chunk_size=1560, top_k=3):
    """Sparse Memory Activation: only attend to top-k most relevant memory chunks.
 
    Args:
        query: [B, H, Lq, D] current chunk visual query
        key:   [B, H, Lm, D] memory bank keys (sink + bank)
        value: [B, H, Lm, D] memory bank values
        chunk_size: tokens per frame (1560 for 832x480 latent)
        top_k: number of memory chunks to activate
    """
    B, H, Lm, D = key.shape
    num_chunks = Lm // chunk_size
 
    # Step 1: Reshape into frame-level chunks
    k_chunks = key.reshape(B, H, num_chunks, chunk_size, D)
    v_chunks = value.reshape(B, H, num_chunks, chunk_size, D)
 
    # Step 2: Mean-pool to get compact descriptors
    q_desc = query.mean(dim=2, keepdim=False)        # [B, H, D]
    k_desc = k_chunks.mean(dim=3)                    # [B, H, num_chunks, D]
 
    # Step 3: Compute relevance scores via inner product
    scores = torch.einsum('bhd,bhnd->bhn', q_desc, k_desc)  # [B, H, num_chunks]
 
    # Step 4: Select top-k chunks
    topk_indices = scores.topk(top_k, dim=-1).indices  # [B, H, top_k]
 
    # Step 5: Gather selected chunks
    topk_indices_exp = topk_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, chunk_size, D)
    k_selected = k_chunks.gather(2, topk_indices_exp).reshape(B, H, top_k * chunk_size, D)
    v_selected = v_chunks.gather(2, topk_indices_exp).reshape(B, H, top_k * chunk_size, D)
 
    # Step 6: Standard attention only on selected memory
    output = flash_attention(query, k_selected, v_selected)
    return k_selected, v_selected

3.4 Memory Bank 压缩与更新(compress_kv_bank)

在推理过程中,memory bank 的更新通过 compress_kv_bank 方法实现:

def compress_kv_bank(kv_cache, new_k, new_v, crossattn_cache,
                     tokens_per_block=1560, memory_budget_in_blocks=3,
                     num_prototypes_in_blocks=1):
    """NAM memory update: retrieve relevant history + add new prototype.
 
    Args:
        kv_cache: current memory bank {k, v}
        new_k, new_v: prototypical KV from last chunk's first frame
        crossattn_cache: text KV for computing semantic scores
        tokens_per_block: 1560 tokens per latent frame
        memory_budget_in_blocks: max frames in bank (b=3)
        num_prototypes_in_blocks: frames from latest chunk (1)
    """
    # Step 1: Get text query from cross-attention cache
    text_query = crossattn_cache["k"]  # text token embeddings
 
    # Step 2: Compute semantic relevance per historical block
    num_blocks = kv_cache["k"].shape[2] // tokens_per_block
    scores = []
    for i in range(num_blocks):
        block_k = kv_cache["k"][:, :, i*tokens_per_block:(i+1)*tokens_per_block]
        attn_weights = softmax(text_query @ block_k.T / sqrt(d))
        score = attn_weights.mean()  # Aggregate to scalar
        scores.append(score)
 
    # Step 3: Select top-(budget - prototype) most relevant blocks
    num_keep = memory_budget_in_blocks - num_prototypes_in_blocks
    top_indices = torch.topk(torch.tensor(scores), num_keep).indices
 
    # Step 4: Concatenate: retrieved history + new prototype
    selected_k = gather_blocks(kv_cache["k"], top_indices)
    selected_v = gather_blocks(kv_cache["v"], top_indices)
    updated_bank_k = torch.cat([selected_k, new_k], dim=2)
    updated_bank_v = torch.cat([selected_v, new_v], dim=2)
 
    return {"k": updated_bank_k, "v": updated_bank_v}

3.5 注意力集成

最终的注意力计算将三部分 KV 拼接:

# Without SMA: concatenate all memory
if not self.SMA:
    k_cat = torch.cat([k_sink, k_bank, k_local], dim=1)  # sink + full bank + local window
    v_cat = torch.cat([v_sink, v_bank, v_local], dim=1)
    output = attention(q, k_cat, v_cat)
 
# With SMA: sparse activation on global memory, then concatenate with local
else:
    k_global = torch.cat([k_sink, k_bank], dim=1)
    v_global = torch.cat([v_sink, v_bank], dim=1)
    # Only keep top-k relevant chunks from global memory
    k_global, v_global = dynamic_topk_routing_attention(
        query=q, key=k_global, value=v_global,
        chunk_size=1560, top_k=3
    )
    k_cat = torch.cat([k_global, k_local], dim=1)
    v_cat = torch.cat([v_global, v_local], dim=1)
    output = attention(q, k_cat, v_cat)

其中:

  • k_sink:Frame Sink(第一帧 chunk 的 KV),继承自 LongLive
  • k_bank:NAM 检索+更新后的 memory bank( 帧的 KV)
  • k_local:local attention window(最近 帧的 KV)

4. Experimental Setup (实验设置)

4.1 评测场景

  • 多 prompt 60 秒交互式生成:用于检验跨 prompt 的主体一致性与叙事连贯性
  • 单 prompt 短视频/长视频:用于检验常规生成质量、语义保持和吞吐速度
  • 消融实验:用于比较不同 memory mechanism 的效果

4.2 基础模型与分辨率

  • Base model:Wan2.1-T2V-1.3B
  • Resolution:832×480
  • 长视频设置:60 秒交互式生成;每次迭代生成 5 秒 clip

4.3 评估指标

  • QualityConsistencyAestheticCLIP (0-10s / 50-60s)
  • FPSTotalSemantic
  • Subject ConsistencyBackground Consistency

4.4 训练与推理设置

  • 采用 Self-Forcing + DMD 蒸馏作为 Stage 1
  • Streaming Long Tuning 作为 Stage 2,在 60 秒长序列上进行训练
  • 论文未详细说明其它超参数

5. Experimental Results (实验结果)

5.1 多 Prompt 60 秒交互式生成

Figure 1 解读:三组对比示例。每组上面是 MemFlow,下面是 LongLive。场景涉及海滩、圣诞树、树林等多 prompt 切换。LongLive 在 prompt 切换后出现主体不一致(如人物外观变化、冗余角色出现),而 MemFlow 通过动态记忆检索始终保持主体一致性。

MethodQuality ↑Consistency ↑Aesthetic ↑CLIP (0-10s)CLIP (50-60s)
SkyReels-V281.5594.7256.8325.3120.91
Self Forcing83.9495.7458.4526.2421.07
LongLive84.2896.0559.8926.6324.11
FramePack84.4096.7759.4426.5121.62
MemFlow85.0296.6061.0726.3124.22

关键发现:

  • MemFlow 在质量美学两项均为最优,验证了记忆机制不仅提升一致性,还缓解了误差累积对画质的影响
  • Consistency 分数仅次于 FramePack(因为 FramePack 降低帧间动态以减少不一致,实际全局一致性不如 MemFlow)
  • CLIP Score 在后半段(40-60s)优势明显,说明 NAM 的语义检索在长序列中持续提供正确的历史上下文

Figure 3 解读:超市购物场景的 60 秒多 prompt 定性对比。SkyReels-V2 出现严重的主体不一致(人物面部漂移);FramePack 色调漂移;Self Forcing 与 LongLive 在后续 prompt 中引入了冗余角色。MemFlow 始终保持同一对男女的外观一致。

5.2 单 Prompt 短视频/长视频

ModelParamsResolutionFPS ↑Total ↑Quality ↑Semantic ↑
Wan2.11.3B832×4800.7884.2685.3080.09
CausVid1.3B832×48017.081.2084.0569.80
Self Forcing (chunk)1.3B832×48017.084.3185.0781.28
LongLive1.3B832×48020.384.8786.9776.47
MemFlow1.3B832×48018.785.1485.9581.90
  • MemFlow 的 Total Score (85.14)Semantic Score (81.90) 均为同级模型最优
  • FPS 18.7 虽然比 LongLive (20.3) 略低 7.9%,但仍然是实时级别
  • Semantic Score 大幅领先 LongLive(81.90 vs 76.47),验证了 NAM 的语义检索优势

5.3 消融实验

Memory MechanismSubject Consist. ↑Background Consist. ↑FPS ↑
w/o Memory94.4195.1523.5
Frame Sink97.6696.2020.3
NAM+SMA(完整模型)98.0196.7018.7
NAM(无 SMA)98.0596.5717.6

关键结论:

  • NAM 显著提升一致性:相比 Frame Sink,Subject Consistency 从 97.66 → 98.05
  • SMA 几乎无损提速:NAM 17.6 FPS → NAM+SMA 18.7 FPS(提升 6.3%),而一致性基本不变
  • Memory bank 容量 最优:过大的 bank 导致全局上下文比例失衡,反而损害短期叙事连贯性

5.4 局限与展望

论文提到的局限

  • 容量敏感性 的最优值高度依赖 local window 大小, 过大(如 6、9)反而导致性能下降甚至不稳定,说明全局-局部上下文比例需要精细平衡
  • 首帧原型假设:Redundant Removal 仅用首帧作为原型,在帧间运动剧烈时信息损失可能较大

潜在改进方向

  • 自适应 bank 容量:根据场景复杂度和 prompt 变化频率动态调整
  • 多帧原型:对运动剧烈的 chunk 使用多帧或加权融合原型,减少信息损失
  • 扩展到更高分辨率/更长时长:当前验证在 832×480 / 60 秒,scaling 到更大规模需要进一步验证 SMA 的效率优势
  • 与其他 base model 结合:论文声称兼容任何使用 KV cache 的流式生成模型,值得在更大模型上验证

5.5 代码映射(Code Mapping)

论文概念代码位置说明
整体模型定义model/dmd_switch.pyDMDSwitch(DMD)继承 DMD,添加 prompt 切换和 memory bank 支持
NAM: Semantic Retrievalwan/modules/causal_model.pycompress_kv_bank()计算 text query 与 memory bank 各 block 的 cross-attn 分数,topk 选择
NAM: Redundant Removalwan/modules/causal_model.pycompress_kv_bank()取上一 chunk 首帧 KV 为 prototype,与检索结果拼接
SMA: Sparse Memory Activationwan/modules/causal_model.pydynamic_topk_routing_attention()mean-pool → inner product → top-k chunk 选择 → sparse attention
注意力集成 (sink + bank + local)wan/modules/causal_model.py → forward 中的 torch.cat([k_sink, k_bank, k_local])SMA 模式下先 sparse select 再 cat
Memory bank 更新触发wan/modules/causal_model.py_apply_cache_updates_before()累积 token 超过一个 block (1560) 时触发 compress
交互式推理pipeline/interactive_causal_inference.py多 prompt 生成主循环,含 prompt switch + bank update
Streaming Long Trainingpipeline/streaming_switch_training.pyStage 2 训练:streaming rollout + prompt switch + NAM/SMA
训练入口train.py + train_init.sh (Stage 1) / train_long.sh (Stage 2)两阶段训练脚本
推理入口inference.py (单 prompt) / interactive_inference.py (多 prompt)推理脚本
配置文件configs/bank_size、local_attn_size 等超参数
Benchmarkprompts/interactive_benchmark.jsonl100 组叙事脚本,每组 6 个 10s prompt