1. Motivation (研究动机)

当前主流 LLM 的部署范式是”train once, deploy static”——权重在训练后就被冻结。这一设计天然地限制了模型对推理阶段动态输入流的自适应能力,在 long-horizon / streaming / 持续学习场景下表现明显不足。

Test-Time Training (TTT) 是一个很早就被提出(Sun et al., 2020)、近年重新被重视的替代范式:在推理阶段允许一小部分参数(所谓 fast weights)持续更新,通过最小化一个自监督目标把 context 信息”压缩”进这些权重中,相当于给模型一个”在线可塑性”状态。

但在 LLM 生态下 TTT 的潜力尚未释放,原因有三:

  1. 架构不兼容:现有 TTT 方法大多作为替代 attention 的新层出现(如 TTT-Linear / TTT-MLP / LaCT 等),必须从头预训练,完全不兼容 LLaMA、Qwen 这种已经预训练好的主流架构。
  2. 计算效率瓶颈:canonical TTT 是 per-token 序列更新,难以并行化;chunk-wise 加速方案(Titans、LaCT 等)为了保留 attention 替代能力不得不使用小 chunk(C=64~256),仍然浪费大量 GPU 并行度。
  3. 学习目标错位:常见 TTT 目标是把 value 设成 自身(reconstruction target),与 LLM 的 Next-Token-Prediction 目标缺乏显式对齐,理论上无益于提高正确 token logit。

作者想要一个真正可直接用于已有大模型的 TTT 变体,即不动 attention、不引入新层、不需要从头训练,同时在效率与目标上都针对 LLM 做出针对性设计。

2. Idea (核心思想)

作者提出 In-Place Test-Time Training (In-Place TTT),核心思想三点:

  1. 复用 MLP 的 down projection 作为 fast weights——不引入任何新层。将 gated MLP 中的 W_down 视作”动态可调”的一部分,在推理阶段对其做 online 更新;W_gate / W_up 保持 frozen。这样整体 Transformer 结构完全不变,“drop-in” 即可接入 Qwen3、LLaMA3 等预训练模型。
  2. chunk-wise 大 chunk 并行更新——由于 TTT 作用于 MLP 而非 attention(后者仍然做 token mixing),可以使用 C=512 或 1024 这种较大的 chunk size,在保证信息粒度的同时充分发挥 GPU 并行度。
  3. LM-Aligned Objective——把 TTT 的 target 从”当前 token”改为”未来若干 token 的局部组合”:。并以 (Frobenius inner product)作为损失,使 closed-form 的权重更新直接指向 NTP 目标。论文用 induction head 设定下的 theorem 证明了这一目标能在一步更新下把 correct next token 的 logit 显著拉高。

凭借这三点,In-Place TTT 成为一个天然兼容 context parallelism(CP)、且保持严格因果的 drop-in 长上下文增强模块;在 Qwen3-4B-Base + 35B tokens 的 continual training 下,RULER 128k 精度从 74.8 → 77.0,并外推到 256k。

3. Method (方法)

3.1 总体框架

Figure 1 解读:整张 pipeline 图的上半路径是 attention(完全保留原状,不参与 TTT),下半路径是 MLP block,此处被替换为 “MLP with In-Place TTT”。输入序列被切成 chunk(右侧 “Split into chunks”),每个 chunk 经历”Apply 再 Update”的两步循环:

  • Apply:用当前 fast weights 处理 chunk 的中间激活 ,得到该 chunk 的输出;
  • Update:基于 做一次梯度步,得到 供下一 chunk 使用。

其中 value target Conv1D & Projection 从 token embedding 中前瞻计算(这是 LM-Aligned Objective 的核心)。在文档边界处 会被 reset 回预训练值,避免跨文档信息泄漏;ΔW^{(i)} 之间通过 prefix-sum 做并行累加,从而支持 context parallelism。

3.2 复用 MLP 作为 fast weights

Figure 1 解读(放大 MLP 子图):图中”MLP with In-Place TTT”模块标出三类权重:frozen 的 W_gateW_up(绿色),以及 dynamic 的 W_down(橙色箭头),体现了”in-place 替换现有 MLP 的 down projection”这一核心设计——没有引入新模块、没有改变总参数量。

Gated MLP 的标准形式:

作者把 视作 frozen slow weights(承载预训练通用知识),把 视作 fast weights(承载当前 context 的瞬态信息)。选择 的两个原因:(1) 它是一个大的矩阵(intermediate_size × hidden_size),足以充当高容量的动态 memory;(2) Geva et al. (2020) 揭示 Transformer FFN 本身就在扮演 key-value memory 角色,将其做 in-place 更新在语义上是自然延伸。

Chunk-wise Apply-then-Update。设中间激活 ,value target 与输出 。把 按 chunk size 切成 个 chunk,令 ,对每个 chunk

  1. Apply
  2. Update 即下文 §3.3 / 代码 ttt_lr

注意:由于 attention 仍然做 token mixing,In-Place TTT 不再需要每 token 更新才能维持 causality,从而允许 C 大到 1024,这是相比 LaCT、Titans 等 TTT-replaces-attention 方案的关键优势。

3.3 LM-Aligned 目标与 closed-form update

Figure 1 解读(放大 LM-Aligned 子图)Conv1D & Projection 模块从 raw token embedding 中计算 value target ,其中 Conv1D 做局部 token 的前瞻组合(causal padding),Projection()把组合结果投影到 空间。这一分支是 In-Place TTT 与常规 reconstruction-TTT 的关键差异。

Target 的设计。常规 TTT 使用 (即当前 token 的投影),这让 “记忆自己”,并不能直接贡献 next-token logit。作者把 target 改为:

其中 是 token embedding, 的 kernel 控制未来若干 token 的局部组合(causal padding 保证 chunk 的 target 不看 chunk 以后的数据), 是可学习的投影。极端情况下把 Conv 核设为 [0, 1, 0, …, 0]、 即恢复 Next-Token Target;一般情况下近似一个 Multi-Token Prediction(MTP)式的富预测信号。

Loss function:使用 Frobenius inner product(负相似度),

由此导出 closed-form 的梯度更新(论文 Eq. (1),用 表示 learning rate,与论文 Theorem 1 记号一致;代码中对应的超参名 ttt_lr 默认值 0.3):

这种形式的更新无需任何额外 Adam/SGD 状态,非常轻量。后续伪代码中的变量 ttt_lr 与此处的 对应。

Induction head 下的 logit 分析(Theorem 1 简述)。考虑 出现在位置 ,位置 时查询 ,模型应预测 。把前文 chunk 累加的 fast weight delta 写成 ,于是 query 位置 的 logit 变化是

  • LM-Aligned target:。在 embedding 近似正交()、key-query 对齐()两条标准假设下:
  • Reconstruction target: 只能保证 ,无法显著提升 correct token 的 logit。

这给出了 LM-Aligned Objective 有效的清晰理论动机。

3.4 Context-Parallel 实现(chunk-wise 并行扫描)

Figure 1 解读(放大 CP / prefix-sum 子图):在”Apply & Update”循环右侧画出 的箭头,由 prefix-sum 聚合得到每个 chunk 应用前的 effective 。这一 prefix-sum 结构是 In-Place TTT 可以在 Ulysses / Context Parallelism 下保持严格因果的关键。

作者指出更新规则 具有 associativity,因此可以做 prefix sum 并行化:

  1. 对所有 chunk 并行计算
  2. 做一次 prefix sum 得到累积
  3. 并行计算每个 chunk 的 effective weight ,以及输出

因果性的两个细节:(1) 在生成 value target 时 Conv1D 使用 causal padding,保证 chunk 不含 chunk 之后的信息;(2) 在 document boundary 处把 reset 到预训练初值,避免跨样本信息泄漏。这两点加上 prefix-sum 等价于”严格串行 chunk 更新”,但吞吐上大幅提升。

仓库中的实现用单个 torch.cumsum(dim=chunk_dim) 完成 prefix sum:先把 沿 chunk 轴堆叠为 (B, chunk_num, d_ff, d_model),并在最前端拼接 自身作为 “chunk 0 的 effective weight”;cumsum 之后每个位置就是

3.5 Python/PyTorch 伪代码(关键组件)

以下伪代码对应源仓库 hf_models/hf_qwen3/modeling_qwen3.py(训练 / continual-pretrain 路径)与 inference_model/hf_qwen3/modeling_qwen3.py(推理路径)的 Qwen3MLP

组件 1:TTT-enabled MLP 初始化与 chunk padding

import torch
from einops import rearrange
from transformers.activations import ACT2FN
 
 
class Qwen3MLPTTT(torch.nn.Module):
    """Drop-in TTT-enabled gated MLP (Sec. 3.1).
 
    Mirrors Qwen3MLP.__init__ and Qwen3MLP.padding in hf_models/hf_qwen3/modeling_qwen3.py.
    """
 
    def __init__(self, config, layer_idx=None):
        super().__init__()
        H, F = config.hidden_size, config.intermediate_size
        self.gate_proj = torch.nn.Linear(H, F, bias=False)
        self.up_proj = torch.nn.Linear(H, F, bias=False)
        self.down_proj = torch.nn.Linear(F, H, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]  # e.g. SiLU for Qwen3
        self.layer_idx = -1 if layer_idx is None else layer_idx
 
        if getattr(config, "ttt_mode", False) and self.layer_idx in getattr(
            config, "ttt_layers", []
        ):
            self.ttt_chunk = getattr(config, "ttt_chunk", 8192)
            self.ttt_lr = getattr(config, "ttt_lr", 0.3)
            self.ttt_proj = (
                torch.nn.Linear(H, H, bias=False)
                if getattr(config, "ttt_proj", True) else None
            )
            self.ttt_conv = torch.nn.Conv1d(
                H, H, kernel_size=5, padding=2, groups=H, bias=False,
            )
 
    def padding(self, x):
        """Pad seq-len to a multiple of ttt_chunk, then reshape into chunks."""
        if not hasattr(self, "ttt_chunk"):
            return x
        if x.shape[1] % self.ttt_chunk != 0:
            pad = torch.zeros(
                x.shape[0], self.ttt_chunk - x.shape[1] % self.ttt_chunk, x.shape[2],
                device=x.device, dtype=x.dtype,
            )
            x = torch.cat([x, pad], dim=1)
        return rearrange(x, "b (t c) d -> b t c d", c=self.ttt_chunk)

组件 2:LM-Aligned target 计算(Conv1D + 可选 Projection)

def compute_lm_aligned_target(self, t_chunks: torch.Tensor) -> torch.Tensor:
    """Given chunked hidden states t_chunks (B, T, C, H), return target V_hat with the
    same shape. Mirrors the Conv1D section of Qwen3MLP.forward.
 
    - Conv1d is depthwise (groups=H) with kernel=5, padding=2: local look-ahead.
    - Optional linear projection ttt_proj of shape (H, H).
    - When used inside the cumsum path below, causality is enforced by slicing
      [:, :-1] of (h, t) before contraction, so chunk i only updates W for chunk i+1.
    """
    bs, chunk_num, chunk_size, _ = t_chunks.shape
    t_conv = (
        self.ttt_conv(t_chunks.transpose(-1, -2).reshape(bs * chunk_num, -1, chunk_size))
        .transpose(-1, -2)
        .reshape(bs, chunk_num, chunk_size, -1)
    )
    return t_conv  # ttt_proj applied inside the einsum in 组件 3

组件 3:训练 forward —— chunk-wise 并行扫描(prefix-sum 形式的 CP 实现)

from einops import repeat
from opt_einsum import contract
 
 
def Qwen3MLPTTT_forward(self, x: torch.Tensor, t=None):
    """Training path. Mirrors Qwen3MLP.forward in hf_models/hf_qwen3/modeling_qwen3.py.
 
    Uses cumsum over ΔW chunks to realize the prefix-sum CP formulation in §3.4:
        W^{(i-1)} = W_down + ttt_lr * sum_{j<i} ΔW^{(j)}.
    """
    h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
    if t is None or not hasattr(self, "ttt_conv"):
        return self.down_proj(h)
 
    t_chunks = self.padding(t)                  # (B, T, C, H)
    h_chunks = self.padding(h)                  # (B, T, C, d_ff)
    bs, chunk_num, chunk_size, _ = t_chunks.shape
 
    t_conv = self.compute_lm_aligned_target(t_chunks)
 
    # ΔW^{(j)} for j = 0 .. T-2  (last chunk never "writes back").
    if self.ttt_proj is not None:
        d_down_proj = contract(
            "b t c h, b t c d, d e -> b t e h",
            h_chunks[:, :-1], t_conv[:, :-1], self.ttt_proj.weight,
        )
    else:
        d_down_proj = contract(
            "b t c h, b t c d -> b t d h",
            h_chunks[:, :-1], t_conv[:, :-1],
        )
 
    # Prepend W_down as the "W^{(0)}" summand so cumsum gives W^{(i-1)} at index i.
    d_down_proj = torch.cat(
        [
            repeat(self.down_proj.weight, "d h -> b 1 d h", b=bs),
            d_down_proj * self.ttt_lr,
        ],
        dim=1,
    )
    d_down_proj_sum = d_down_proj.cumsum(dim=1)     # prefix-sum along chunk axis
 
    down_proj = contract("b t d h, b t c h -> b t c d", d_down_proj_sum, h_chunks)
    return rearrange(down_proj, "b t c d -> b (t c) d")[:, : x.shape[1], :]

组件 4:推理 forward —— Apply-then-Update(顺序更新 fast weight,跨 chunk 携带)

def Qwen3MLPTTT_inference(self, x: torch.Tensor, t=None, past_w=None):
    """Inference path. Mirrors Qwen3MLP.forward in inference_model/hf_qwen3/modeling_qwen3.py.
 
    Returns (output, current_w). The caller is expected to cache current_w across calls
    (e.g. inside TTTDynamicCache) so the fast weight persists over the generation trajectory.
    """
    h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
    if not hasattr(self, "ttt_conv"):
        return self.down_proj(h)
 
    present_w = self.down_proj.weight.clone() if past_w is None else past_w
    if t is None:
        return torch.nn.functional.linear(h, present_w), present_w
 
    bs, seq_len, _ = x.shape
    if seq_len < self.ttt_chunk:
        return torch.nn.functional.linear(h, present_w), present_w
 
    t_chunks = self.padding(t)
    h_chunks = self.padding(h)
    bs, chunk_num, chunk_size, _ = t_chunks.shape
    t_conv = self.compute_lm_aligned_target(t_chunks)
 
    cur_w = present_w
    y = torch.zeros_like(t_conv)
    # NOTE: inference path in inference_model/hf_qwen3/modeling_qwen3.py only
    # exercises the bs=1 streaming case (KV cache mirrors single-sequence eval),
    # hence the indexing on [0, i] below. For B>1, wrap the inner loop in a
    # for-bsi-in-range(bs) and keep an independent cur_w per batch element.
    for i in range(chunk_num):
        cur_h = h_chunks[0, i]                          # (C, d_ff)
        cur_t = t_conv[0, i]                            # (C, H)
        y[0, i] = contract("d h, c h -> c d", cur_w, cur_h)   # Apply
        # Update: skip ΔW write-back for the last (possibly partial) chunk.
        if seq_len % self.ttt_chunk == 0 or i != chunk_num - 1:
            if self.ttt_proj is not None:
                dw = contract(
                    "c h, c d, d e -> e h",
                    cur_h, cur_t, self.ttt_proj.weight,
                ) * self.ttt_lr
            else:
                dw = contract("c h, c d -> d h", cur_h, cur_t) * self.ttt_lr
            cur_w = cur_w + dw
 
    out = rearrange(y, "b t c d -> b (t c) d")[:, :seq_len, :]
    return out, cur_w

组件 5:Custom weight init for TTT submodules(与 Qwen3PreTrainedModel._init_weights 的 TTT 分支严格对齐)

import torch
import torch.nn as nn
import torch.distributed as dist
 
 
def init_ttt_weights(module: nn.Module, initializer_range: float = 0.02):
    """Mirrors Qwen3PreTrainedModel._init_weights (TTT branch) in
    hf_models/hf_qwen3/modeling_qwen3.py. Applies to:
      - ttt_proj  (nn.Linear, square matrix H x H)  -> zero everywhere, random-normal on diagonal
      - ttt_conv  (nn.Conv1d, kernel=5, groups=H)   -> zero init
    """
    std = initializer_range
 
    if isinstance(module, nn.Linear):
        if module.weight.device.type == "meta":
            return
        # Only square matrices (i.e. ttt_proj, H x H) receive diagonal init; other Linear
        # layers (gate_proj / up_proj / down_proj) are loaded from the pretrained checkpoint.
        if module.weight.shape[0] == module.weight.shape[1]:
            diag_size = module.weight.shape[0]
            weight_data = module.weight.data
 
            if hasattr(weight_data, "_local_tensor"):
                # FSDP2 DTensor: operate on local shard; assume row-sharded.
                local = weight_data._local_tensor
                local.zero_()
                local_rows = local.shape[0]
                num_cols = local.shape[1]
                rank = dist.get_rank()
                start_row = rank * local_rows
 
                g = torch.Generator(device=local.device).manual_seed(42)
                diag_values = torch.randn(
                    diag_size, generator=g, device=local.device, dtype=local.dtype,
                ) * std
 
                local_idx = torch.arange(local_rows, device=local.device)
                global_cols = start_row + local_idx
                mask = global_cols < num_cols
                local_idx = local_idx[mask]
                global_cols = global_cols[mask]
                if len(local_idx) > 0:
                    local[local_idx, global_cols] = diag_values[global_cols]
            else:
                weight_data.zero_()
                diag_values = torch.randn(
                    diag_size, device=weight_data.device, dtype=weight_data.dtype,
                ) * std
                idx = torch.arange(diag_size, device=weight_data.device)
                weight_data[idx, idx] = diag_values
        if module.bias is not None:
            nn.init.zeros_(module.bias)
 
    elif isinstance(module, nn.Conv1d):
        # ttt_conv: zero init so the model starts as vanilla MLP (ΔW ≈ 0 initially).
        module.weight.data.zero_()
        if module.bias is not None:
            module.bias.data.zero_()

关键点:ttt_proj 的对角线初始化为 )而非恒等矩阵;非对角线位置严格清零;ttt_conv 完全置零。实现上对 DTensor(FSDP2)做了分片感知:local shard 只写属于自己那部分行的对角元素,全体 rank 共享固定种子 42 以保证跨 rank 一致。ttt_conv 置零可保证训练早期 ,TTT 分支从 vanilla MLP 出发逐步学习;ttt_proj 对角线的小幅随机扰动避免了退化解。

Code-to-paper mapping table

Paper ConceptSource FileKey Class / Functiongithub_ref
In-Place TTT 主体(§3.1,训练 / CP 路径)hf_models/hf_qwen3/modeling_qwen3.pyQwen3MLP.__init__Qwen3MLP.paddingQwen3MLP.forward(x, t)ByteDance-Seed/In-Place-TTT@be232482
In-Place TTT 推理(Apply-then-Update,§3.1)inference_model/hf_qwen3/modeling_qwen3.pyQwen3MLP.forward(x, t, past_w)(返回 (output, current_w),由 TTTDynamicCache 跨 step 缓存 fast weight)ByteDance-Seed/In-Place-TTT@be232482
LM-Aligned target = Conv1D + Projection(§3.3)hf_models/hf_qwen3/modeling_qwen3.pyself.ttt_conv(depthwise Conv1d, kernel_size=5, padding=2, groups=H, bias=False)+ self.ttt_proj(Linear H→H,可通过 config.ttt_proj=False 关闭)ByteDance-Seed/In-Place-TTT@be232482
Chunk-wise prefix-sum parallel scan(§3.4)hf_models/hf_qwen3/modeling_qwen3.pyQwen3MLP.forward 中的 d_down_proj.cumsum(dim=1) 段,对应 的 prefix sumByteDance-Seed/In-Place-TTT@be232482
TTT 子模块权重初始化(对角随机 + 零 Conv,DTensor-aware)hf_models/hf_qwen3/modeling_qwen3.pyQwen3PreTrainedModel._init_weightsnn.Linear square-matrix 分支:local-shard zero + 对角 randn(..) * 0.02,固定 seed=42;nn.Conv1d 分支:全零)ByteDance-Seed/In-Place-TTT@be232482
Decoder layer(默认把 hidden_states 同时作为 target)hf_models/hf_qwen3/modeling_qwen3.pyQwen3DecoderLayer.forwardif target_states is None and self.is_ttt_layer: target_states = hidden_states; hidden_states = self.mlp(hidden_states, t=target_states)ByteDance-Seed/In-Place-TTT@be232482
TTT 作用于哪些 layer 的配置hf_models/hf_qwen3/configuration_qwen3.pyttt_modettt_layersttt_chunkttt_lrttt_projByteDance-Seed/In-Place-TTT@be232482
Qwen3-4B-Base 的 continual pretrain 配置configs/pretrain/qwen3_longct.yamlttt_layers=[0,6,12,18,24,30,36], ttt_mode=true, ttt_proj=true, ttt_lr=3, ttt_chunk=4096ByteDance-Seed/In-Place-TTT@be232482
LLaMA-3.1-8B 同步版本configs/pretrain/llama3_longct.yamlhf_models/hf_llama/*LlamaMLP(与 Qwen3MLP 对称实现)ByteDance-Seed/In-Place-TTT@be232482
Train entrytasks/train_torch.pytrain.shByteDance-Seed/In-Place-TTT@be232482
Infer entrytasks/infer.pyeval.shByteDance-Seed/In-Place-TTT@be232482
RULER 4k~256k benchmark 配置eval_config/ruler_*.pyByteDance-Seed/In-Place-TTT@be232482
DCP → HF checkpoint 合并scripts/merge_dcp_to_hf.pyByteDance-Seed/In-Place-TTT@be232482

4. Experimental Setup (实验设置)

作者围绕三个研究问题设计实验(Sec. 4):

Q1:In-Place TTT 作为 drop-in 增强预训练 LLM 的效果? Q2:从零训练时,与现有 TTT / 线性 / SWA 方案相比的优劣? Q3:关键设计(state size、chunk size、LM-aligned objective)各自的贡献?

Q1:Drop-in enhancement on pretrained LLMs

  • 主实验:Qwen3-4B-Base(原生 32k 上下文);
  • 两阶段 continual training:Stage 1 约 20B tokens @ 32k context,Stage 2 约 15B tokens @ 128k context;
  • 用 YaRN 对 RoPE 做 scale 以支持长上下文;
  • Baseline 与 In-Place TTT 使用完全相同的 curriculum、只切换是否启用 TTT;
  • 评测:RULER 4k~256k(OpenCompass),256k 专门检验 extrapolation;
  • 扩展实验:同协议应用到 LLaMA-3.1-8B 和 Qwen3-14B-Base(额外加入 64k + YaRN 组)。

Q2:From-scratch pre-training

  • 500M / 1.5B 比较对象:Transformer + SWA、GLA、DeltaNet、LaCT(SWA 底)、In-Place TTT(SWA 底);
  • 32k sequence length、TogetherAI 数据;
  • 主指标:Sliding Window Perplexity(固定末尾 block、延长前缀)on Pile + Proof-Pile-2;
  • 4B 比较对象:Full-Attn vs. Full-Attn + In-Place TTT,SWA vs. SWA + In-Place TTT,120B tokens @ 8k 预训练;
  • 评测:HellaSwag / ARC-E / ARC-C / MMLU / PIQA 常识推理 + RULER 4k/8k/16k。

Q3:Ablation

  • 模型:1.7B、RULER benchmark;
  • 扫描维度:state size(TTT 层数,0.5× / 1× / 4× baseline state size)、chunk size(256 / 512 / 1024 / 2048)、LM-Aligned 目标(w Conv, Proj / w/o Conv / w/o Proj / w/o Conv, Proj)。

默认超参

  • 论文 §4.3 推荐配置:ttt_chunk=1024(兼顾效率与精度,C=512 / C=1024 同档最优),代码层面 getattr(config, "ttt_chunk", 8192) 提供保守 fallback;
  • configs/pretrain/qwen3_longct.yaml(Qwen3-4B-Base 实际使用):ttt_layers=[0, 6, 12, 18, 24, 30, 36]ttt_mode=truettt_proj=truettt_lr=3ttt_chunk=4096
  • ttt_proj:square Linear H→H,按 Qwen3PreTrainedModel._init_weights 对角随机初始化 ),非对角清零;并非恒等矩阵;
  • ttt_conv:groups=H 的 5-tap depthwise Conv1d(kernel_size=5, padding=2),权重 zero-init,使训练初期
  • continual training 使用 FSDP2 + Context Parallelism + flash attention。

5. Experimental Results (实验结果)

5.1 RULER 主表(Qwen3-4B-Base)

Model4k8k16k32k64k128k256k
Mistral-7B93.691.287.275.449.013.8
GLM3-6B87.883.478.669.956.042.0
Phi3-medium-14B93.393.291.186.878.646.1
Llama3-8B92.890.385.779.976.369.5
Qwen3-4B (Instruct)95.193.691.087.877.866.0
Baseline (Qwen3-4B-Base + continual training)96.694.192.188.774.374.841.7
+ In-Place TTT96.195.692.789.378.777.043.9

核心现象:

  • 在 4k 上略低 0.5pp(有限噪声),在 8k 起的所有长度上稳定胜出 baseline;
  • 64k / 128k / 256k 分别 +4.4 / +2.2 / +2.2 pp,即越长越有效;
  • 256k 是 extrapolation 区(训练最长到 128k),In-Place TTT 仍保持增益,表明其确实在”压缩 context 信息”而不是简单记忆位置模式。

扩展:LLaMA-3.1-8B / Qwen3-14B-Base(Tab. 2)

Base ModelMethod4k8k16k32k64k64k+YaRN
LLaMA-3.1-8BBaseline93.992.192.591.181.6
LLaMA-3.1-8B+ In-Place TTT94.493.093.391.783.7
Qwen3-14B-BaseBaseline96.895.094.690.767.981.3
Qwen3-14B-Base+ In-Place TTT97.295.795.291.270.682.5
  • 不同模型家族(LLaMA vs. Qwen3)、不同规模(4B/8B/14B)均保持一致正收益;
  • 与 RoPE-scale 技术(YaRN)正交,可叠加使用。

5.2 From-scratch 比较:500M / 1.5B perplexity

Figure 2 解读:两张图分别展示 500M / 1.5B 模型在 Pile 验证集上的 Sliding Window Perplexity 曲线。x 轴 context length(2k~32k),y 轴 PPL。

  • 500M:SWA、GLA、DeltaNet 的 PPL 在 8k 之后都明显回升,LaCT 稍好但仍呈上翘趋势;In-Place TTT 唯一保持单调下降,说明它是唯一能有效压缩长上下文信息的架构。
  • 1.5B:所有方案绝对 PPL 更低;In-Place TTT 仍然稳定领先约 0.1~0.3 PPL。

5.3 From-scratch 4B 表(Tab. 3)

ModelArchitectureHellaSwagARC-EARC-CMMLUPIQARULER-4kRULER-8kRULER-16k
BaselineFull Attn.55.6764.5233.1936.4372.6345.7738.096.58
BaselineSWA54.9264.1832.8536.0672.5814.779.915.07
I.P. TTTFull Attn.55.8564.9832.3437.4273.2949.9843.8219.99
I.P. TTTSWA55.2464.6033.7036.4872.0328.3326.807.57

关键结论:

  • 常识推理基本全面提升(HellaSwag / ARC-E / MMLU / PIQA 均更高);
  • Long-context 提升最为显著:RULER-16k 下 Full-Attn+TTT 从 6.58 → 19.99(3× gain),SWA+TTT 从 9.91 → 26.80(2.7× gain),证明 In-Place TTT 的核心价值在 long-context 压缩而非 short context。

5.4 Ablation:state size / chunk size / objective

State size(Figure 3(a)):按 TTT-enabled layer 数控制 state 大小(0.5× / 1× / 4×)。结果 4× > 1× > 0.5×,即更多 TTT 层→ 更大 fast-weight state → 更强 long-context。这说明复用 MLP state 的方向是对的,还有进一步 scale 的空间。

Chunk size(Figure 3(b)):C ∈ {256, 512, 1024, 2048}。结果 512 / 1024 最优;256 过细、2048 过粗。验证”In-Place TTT 天然支持大 chunk”(这是它与 LaCT 等 TTT-as-attention 方案的关键差异)。C=1024 在精度相近前提下吞吐最高,被选作默认。

LM-Aligned objective(Figure 3(c)):

VariantRULER-4kRULER-8kRULER-16k
w Conv, Proj(默认)最高最高最高
w/o Conv(仅 Proj)明显下降显著下降显著下降
w/o Proj(仅 Conv)小幅下降小幅下降中等下降
w/o Conv, Proj(reconstruction)最低最低最低

Conv 在长上下文上决定性地重要(让 target 包含未来 token 信息),Projection 在短上下文上更关键(灵活表达映射)。二者都去掉即退化成 reconstruction target,全面变差。这与 Theorem 1 的预测完全吻合。

5.5 效率

Figure 4(a-b) 解读:分别是 SWA 底与 Full-Attn 底下的 prefill throughput(K tokens/s,越高越好)。x 轴 context length(8k / 32k / 128k)。Baseline vs. TTT 曲线几乎重合,最大差距约 5~8%,说明 In-Place TTT 对推理吞吐几乎不增加开销。

Figure 4(c-d) 解读Peak memory (GB)。TTT 版本的内存占用与 Baseline 基本一致——因为 fast weight 与 MLP 原有权重共享 slot,额外只存一个 clone、Conv1D 与 Proj 几兆参数。这再次强化”drop-in”承诺:不需要额外硬件或大幅重训。

5.6 Sanity-check:TTT results 细节

Figure 扩展解读:对 500M / 1.5B 模型做 Sliding Window PPL(y 轴)随 context length(x 轴)变化的放大图,可看到 In-Place TTT 的曲线相对 baseline 系列在 8k→32k 区间有单调分离的下降斜率,而 SWA/GLA/DeltaNet/LaCT 会在某个长度后停止下降甚至回升。这一”长度越长收益越大”的趋势与 Table 1 的 4B 结果一致。

5.7 Discussion & 局限

  • 结论:In-Place TTT 提供了一条不破坏预训练权重兼容 context parallelism具有 LM-aligned 理论保证的 fast-weights 实现路径。它让 Qwen3-4B 这种既有模型以 35B tokens 的 continual training 就能把 128k RULER 提升约 +2.2 pp,是目前对现有 LLM 最友好的 TTT 方案之一。
  • 推广:该框架可以自然地叠加到 Full-Attn、SWA、GLA、DeltaNet 等多种 backbone,以及 YaRN、NTK-scaling 等 RoPE 扩展方法。
  • 局限
    • LM-aligned objective 的理论证明在 induction-head 单层简化场景下给出,未覆盖多层复杂 circuit;
    • 当前只实例化了 down_proj 的更新,其他 MLP 项(gate/up)或 attention 投影能否同样重新用于 TTT 未展开;
    • Chunk size、state size、lr 等仍然需要按模型规模调节,缺少 auto-tune;
    • 依赖 YaRN 做 RoPE scale;极端超长(>256k)行为有待继续 scale 数据与训练。