1. Motivation (研究动机)

“SFT memorizes, RL generalizes” 已成为 post-training 领域主导叙事:Chu et al. (2025)、Huan et al. (2025) 等多项工作得出 SFT 只能记忆领域内模式、无法 OOD 泛化的结论。这一叙事直接影响了 post-training 流程设计(更多 RL、更少 SFT)以及对 SFT 目标的改造(Wu et al., 2026;Zhu et al., 2026)。

但现有研究之间的实验条件并不一致

  • Chu et al. (2025) 未使用 long CoT 监督;
  • Huan et al. (2025) 只训练很短的 epoch;
  • Wu et al. (2026) 使用质量不均的数据;
  • 多数研究受算力限制,只使用小或旧的 base model;
  • 多数”SFT vs. RL”比较关注 retention(是否破坏既有能力)而非真正的泛化获取;
  • 多数研究从 instruction-tuned 模型出发,alignment 对实验结果会产生严重混淆。

这些交织的因素导致文献中报告的”SFT 不泛化”到底是SFT 的本质属性还是实验条件不足的副作用并不清楚。特别是在 long-CoT reasoning SFT 场景下(数据结构更复杂、拟合难度更大、更依赖 base model 能力),这种条件性被进一步放大。

本文的目标是在受控实验下厘清 reasoning SFT 的泛化边界:泛化不是 SFT 的固有属性,而是由优化动态、训练数据、模型能力三要素共同决定的条件现象

2. Idea (核心思想)

论文提出一条结构化的论断:“Generalization in reasoning SFT is CONDITIONAL”,并沿三个因子展开:

Figure 1 解读:作者把 reasoning SFT 的泛化建模为三因子问题(Optimization Dynamics / Training Data / Model Capability)。每个因子都提供了正反面诊断信号:(1)优化动态上,cross-domain 表现存在 “dip-and-recovery” 模式,短训出的 checkpoint 会显著低估 SFT 的泛化能力;(2)数据上,低质量答案会 broadly 伤害泛化,而经过 verify 的 long-CoT 甚至能让一个 toy 算术游戏(Countdown)泛化到多项 reasoning benchmark;(3)模型能力上,强模型会内化 procedural 模式(backtracking / verification),弱模型只能模仿表面冗长度。整体生成一张”条件-性能”图景。

两个额外关键观察:

  • 响应长度是优化阶段的粗略诊断:训练初期 response length 先激增、中期达到峰值、后期回落;长度 vs. 性能呈现反相关。
  • asymmetric generalization:long-CoT SFT 虽然提升 reasoning,但 safety 明显下降,模型学会”自我合理化”并输出有害内容;这一退化几乎只发生在含 long-CoT 的训练数据上,表明其来源是 procedural pattern 而非 math 内容本身。

3. Method (方法)

本论文本质是一套受控实验方法学 + 一个 verl-based SFT 训练管线。没有新算法或新架构,但对训练管线做了自定义扩展(verl/trainer/fsdp_sft_trainer_ours.py),支持重要性采样、响应 token 级别 log-prob、advantage、clip ratio 等变体。以下按”总体框架”和”各关键组件”展开。

3.1 总体框架

整个工作流可划分为 4 个步骤:

  1. 数据构造:从 OpenR1-Math-220k 采样 20,480 条 math query,使用 Qwen3-32B 启用 thinking 生成 response,经 math-verify 验证后保留 correct-only 样本作为 Math-CoT-20k;相应地构造 Math-NoCoT-20k(删除 <think> 块)、NuminaMath-20k(人工解,质量参差)、Countdown-CoT-20k(toy 算术游戏 + CoT)。
  2. SFT 训练:基于 verl,采用 FSDP 后端,标准 SFT 目标(negative log-likelihood over response tokens);AdamW 优化器,learning rate 5e-5,batch size 256,8 epochs,cosine schedule 带 10% linear warmup(warmup_steps_ratio=0.1),最大序列长度 20k tokens,bf16 精度。
  3. 评测矩阵:在 ID reasoning(MATH500、AIME24),OOD reasoning(LCB v2、GPQA-D、MMLU-Pro),general capabilities(IFEval、AlpacaEval 2.0、HaluEval、TruthfulQA),safety(HEx-PHI)上逐 checkpoint 评估,构造”training step vs. score”曲线。
  4. 因子扫描:分别改变优化调度(LR、epoch、schedule)、数据配置(CoT/NoCoT、质量、结构)、base model(Qwen3-1.7B/4B/8B/14B、Qwen2.5 全系、InternLM2.5-20B)以观察条件依赖。

3.2 SFT loss 与训练协议

默认 SFT 目标:

其中 为 response token 集合(不含 query), 为 prompt + 已生成前缀, 为目标 token。

作者在 verl/trainer/fsdp_sft_trainer_ours.py 中把 loss 做了通用化,支持多种重要性采样 / PPO 风格变体,imp_sampling_mode 的完整枚举为:vanilla(默认 NLL)/ teacher-weighted / dft / adv-only / ppo-ori / ppo / clip / cispo / gspo-ori / gspo。所有 loss 都可通过 verl.trainer.ppo.core_algos.agg_losstoken-mean / seq-mean-token-sum / seq-mean-token-mean / seq-mean-token-sum-norm 聚合。

3.2.1 数据 pipeline(OurSFTDataset)

OurSFTDatasetverl/utils/dataset/sft_dataset.py,默认数据类,被 FSDPSFTTrainer 使用)负责把 parquet 样本编码为 (input_ids, attention_mask, position_ids, loss_mask, advantages, ref_log_prob?);关键逻辑:

  • promptresponse 分别用 apply_chat_template / tokenizer 编码,response_ids 末尾强制补 eos_token_id
  • 整条序列 right-pad 到 max_length=20000,超长按 truncation=right 截断;
  • loss_mask[:prompt_length-1] = 0(去掉 prompt),同时 loss_mask[prompt_length+response_length-1] = 0(因为 logits 会 shift 一位,最末 token 无监督信号);
  • ref_log_prob.enable=True(默认开),把 teacher log_prob 按 [prompt_length-1 : prompt_length-1+response_length] 放到与 attention_mask 对齐的 0 填充 tensor 中。
import torch
 
 
def sft_collate(prompt_ids, response_ids, tokenizer, max_length: int):
    """Mirrors OurSFTDataset.__getitem__ in the repo."""
    eos = torch.tensor([tokenizer.eos_token_id], dtype=response_ids.dtype)
    response_ids = torch.cat([response_ids, eos])
 
    input_ids = torch.cat([prompt_ids, response_ids])
    attn = torch.cat([
        torch.ones_like(prompt_ids), torch.ones_like(response_ids),
    ])
 
    seq_len = input_ids.shape[0]
    if seq_len < max_length:
        pad_ids = torch.full((max_length - seq_len,), tokenizer.pad_token_id)
        pad_attn = torch.zeros(max_length - seq_len, dtype=attn.dtype)
        input_ids = torch.cat([input_ids, pad_ids])
        attn = torch.cat([attn, pad_attn])
    else:
        input_ids = input_ids[:max_length]
        attn = attn[:max_length]
 
    position_ids = (attn.cumsum(dim=-1) - 1).clamp_min(0) * attn
 
    prompt_len = prompt_ids.shape[0]
    loss_mask = attn.clone().to(torch.float32)
    if prompt_len > 1:
        loss_mask[: min(prompt_len, loss_mask.size(0)) - 1] = 0.0
    last = min(prompt_len + response_ids.shape[0], loss_mask.size(0)) - 1
    loss_mask[last] = 0.0
    return input_ids, attn, position_ids, loss_mask

3.2.2 _compute_loss_and_backwardFSDPSFTTrainer

import torch
import verl.utils.torch_functional as verl_F
 
 
def _compute_loss_and_backward(self, batch, do_backward: bool = True):
    """Mirrors FSDPSFTTrainer._compute_loss_and_backward in fsdp_sft_trainer_ours.py."""
    input_ids      = batch["input_ids"].to(self.device_name)
    attention_mask = batch["attention_mask"].to(self.device_name)
    position_ids   = batch["position_ids"].to(self.device_name)
    loss_mask      = batch.pop("loss_mask")[:, :-1].to(self.device_name)  # shift-1
    advantages     = batch.pop("advantages").to(self.device_name)
 
    ref_log_prob = batch.get("ref_log_prob")
    if ref_log_prob is not None:
        ref_log_prob = ref_log_prob[:, :-1].to(self.device_name)
    mean_ref_log_prob = batch.get("mean_ref_log_prob")
 
    with torch.autocast(self.device_name, dtype=torch.bfloat16):
        labels = input_ids[:, 1:].contiguous()
        logits = self.fsdp_model(
            input_ids=input_ids, attention_mask=attention_mask,
            position_ids=position_ids, use_cache=False,
        ).logits
        shift_logits = logits[..., :-1, :].contiguous()
 
        log_prob = verl_F.logprobs_from_logits(shift_logits, labels=labels)
        advantages = torch.ones_like(log_prob) * advantages  # broadcast to token-level
 
        loss, *_ = compute_loss(
            log_prob=log_prob,
            response_mask=loss_mask,
            advantages=advantages,
            old_log_prob=ref_log_prob,
            mean_old_log_prob=mean_ref_log_prob,
            imp_sampling_mode=self.config.trainer.importance_sampling_mode,
            cliprange=self.config.trainer.clip_ratio,
            cliprange_low=self.config.trainer.clip_ratio_low,
            cliprange_high=self.config.trainer.clip_ratio_high,
            loss_agg_mode=self.config.trainer.loss_agg_mode,
        )
 
    if do_backward:
        loss.backward()
    return loss.item()

3.2.3 compute_loss:vanilla 与 PPO/CISPO/GSPO 等变体

import torch
from verl.trainer.ppo.core_algos import agg_loss
 
 
def compute_loss(
    log_prob, response_mask,
    advantages=None, old_log_prob=None, mean_old_log_prob=None,
    imp_sampling_mode: str = "vanilla",
    cliprange=None, cliprange_low=None, cliprange_high=None,
    loss_agg_mode: str = "token-mean",
):
    """Mirrors compute_loss(...) in fsdp_sft_trainer_ours.py (vanilla + variants)."""
    cliprange_low = cliprange_low if cliprange_low is not None else cliprange
    cliprange_high = cliprange_high if cliprange_high is not None else cliprange
 
    if imp_sampling_mode == "vanilla":
        # standard SFT: L = - log p_theta(y_t | x, y_<t)
        losses = -log_prob
 
    elif imp_sampling_mode == "teacher-weighted":
        # -pi_teacher * log_prob  (teacher-as-importance-weight)
        w = torch.exp(old_log_prob).detach()
        losses = -w * log_prob
 
    elif imp_sampling_mode == "dft":
        # distillation fine-tune: pi_theta / pi_teacher (clipped) as weight, detached
        weight = torch.exp(log_prob - old_log_prob).detach()
        weight = torch.clamp(weight, 1.0 - cliprange_low, 1.0 + cliprange_high)
        losses = -weight * log_prob
 
    elif imp_sampling_mode == "adv-only":
        losses = -advantages * log_prob
 
    elif imp_sampling_mode in {"ppo", "ppo-ori", "clip", "cispo"}:
        # PPO/CISPO-style: ratio = pi_theta / pi_old
        ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20, 20))
        if imp_sampling_mode == "cispo":
            ratio_clip = torch.clamp(ratio.detach(), 1.0 - cliprange_low, 1.0 + cliprange_high)
            losses = -ratio_clip * log_prob  # clip only importance ratio, keep gradient in log_prob
        else:
            ratio_clip = torch.clamp(ratio, 1.0 - cliprange_low, 1.0 + cliprange_high)
            losses = -torch.min(ratio * advantages, ratio_clip * advantages)
 
    elif imp_sampling_mode in {"gspo", "gspo-ori"}:
        # sequence-level importance ratio (Gradient-Sequence PO)
        delta = torch.clamp(log_prob - old_log_prob, -20, 20)
        seq_len = response_mask.sum(dim=-1).clamp_min(1)
        seq_delta = (delta * response_mask).sum(dim=-1) / seq_len  # (B,)
        seq_ratio = torch.exp(seq_delta).unsqueeze(-1).expand_as(log_prob)
        seq_ratio_clip = torch.clamp(seq_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high)
        losses = -torch.min(seq_ratio * advantages, seq_ratio_clip * advantages)
    else:
        raise ValueError(imp_sampling_mode)
 
    return agg_loss(loss_mat=losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
 
def agg_loss(loss_mat, loss_mask, loss_agg_mode: str = "token-mean"):
    """Mirrors verl.trainer.ppo.core_algos.agg_loss."""
    import verl.utils.torch_functional as verl_F
    if loss_agg_mode == "token-mean":
        return verl_F.masked_mean(loss_mat, loss_mask)
    if loss_agg_mode == "seq-mean-token-sum":
        return torch.mean((loss_mat * loss_mask).sum(dim=-1))
    if loss_agg_mode == "seq-mean-token-mean":
        seq = (loss_mat * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1).clamp_min(1)
        return seq.mean()
    if loss_agg_mode == "seq-mean-token-sum-norm":
        return (loss_mat * loss_mask).sum() / loss_mask.shape[-1]
    raise ValueError(loss_agg_mode)

3.2.4 training_step:FSDP 梯度累积 + all-reduce

def training_step(self, batch):
    """Mirrors FSDPSFTTrainer.training_step in fsdp_sft_trainer_ours.py."""
    self.fsdp_model.train()
    self.optimizer.zero_grad()
 
    micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
    step_loss = 0.0
    for mb in micro_batches:
        step_loss += self._compute_loss_and_backward(mb) / len(micro_batches)
 
    if self.config.model.strategy == "fsdp":
        grad_norm = self.fsdp_model.clip_grad_norm_(self.config.optim.clip_grad)
    else:  # fsdp2
        grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), self.config.optim.clip_grad)
 
    if torch.isfinite(grad_norm):
        self.optimizer.step()
    self.lr_scheduler.step()
 
    step_loss_t = torch.tensor(step_loss, device=self.device_name)
    torch.distributed.all_reduce(step_loss_t, op=torch.distributed.ReduceOp.AVG)
    return {"train/loss": step_loss_t.item(), "train/lr": self.lr_scheduler.get_last_lr()[0]}

默认 imp_sampling_mode=vanillaloss_agg_mode=token-meanclip_ratio_low=clip_ratio_high=1.0,即退化为标准 token-level NLL。

3.3 优化动态:dip-and-recovery + response length

作者用 Qwen3-14B-Base / 8B-Base / InternLM2.5-20B-Base 在默认设置下训 8 epoch,得到 step × benchmark × response-length 的三维表格。

Figure 3 解读:上半部分显示三个模型在 9 个 benchmark 上的性能曲线(x 轴 log scale 的 training steps,y 轴精度或 RM 分数),普遍呈现”dip-and-recovery”:早期 step OOD benchmark 普遍下跌(例如 IFEval、AlpacaEval),中期逐步恢复,后期超过 base model;ID 基准(MATH500、AIME24)则有 brief 小幅 dip 后快速上升。下半部分显示 response length 随 step 的演化:先急剧增长到 2~3 万 token,再逐步回落;峰值长度恰好对应性能最差的阶段——说明此时模型只学到了”写得长”这一表层特征,还未内化分解、回溯、自评估等深层模式。

定量诊断信号

其中 为第 个 checkpoint 的平均 response length, 为该 checkpoint 在某任务的得分。实验中这一相关性在 dip 阶段显著为负。作者由此提出:

Checkpoint is “under-optimized” 的一个实用判据:response length 仍在 noticeable shrink,就不要据此下 “SFT 不泛化” 的结论。

3.4 重复曝光与单次覆盖及 overfitting regime

受控实验:在固定 640 个 gradient step 的预算下对比三种训练安排:Setting1 = 20k / bs 256 / ep 8;Setting2 = 2.5k / bs 32 / ep 8;Setting3 = 20k / bs 32 / ep 1。

Figure 扩展解读(Sec. 3.3 Tab. 1):Setting 2 明显优于 Setting 3(同计算预算,8 次 vs. 1 次数据曝光),说明 long-CoT SFT 从 “多轮重复曝光” 中获得的提升大于 “更多 unique 样本”;这一点解释了为什么文献上”训 1 epoch 就结论”的做法容易 underestimate SFT。Setting 1 再比 Setting 2 好,表明在 epoch 与 step 固定时,数据 diversity 仍有额外价值。

Overfitting stress test:作者在更 aggressive 的调度下训 16 epoch,对比”cosine LR”/“const LR”/“const LR + LR↑”。

Figure 4 解读:前者显示各 benchmark 上的性能,后者显示对应的 response length。(1) 默认 + 16 ep(cosine)几乎与默认曲线重合,末期开始缓慢下降但未显著。(2) const LR 16 ep 后段出现 broad OOD 退化信号。(3) LR=1e-4, const, 16 ep 展现最清楚的 overfitting:即使 ID math 也开始 drop,response length 再次回升——说明模型重新陷入”长响应但低质量”的状态。由此作者画出了 long-CoT SFT 的 underfit→sweet-spot→overfit 三段式优化区间。

3.5 数据质量与结构的因果实验

作者通过 4 种数据对照分离 CoT / 质量 / 结构三种因子:

  • Math-CoT-20k vs. Math-NoCoT-20k:query 与 final answer 完全一致,仅删除 <think> 块 → 分离 “长 CoT 过程” 的因果贡献;
  • Math-NoCoT-20k vs. NuminaMath-20k:都没有 long CoT,但前者是质量较高的答案 → 分离 “数据质量” 的因果贡献;
  • Math-CoT-20k vs. Countdown-CoT-20k:query domain 不同(数学 vs. Countdown 算术游戏),但都是 long-CoT → 分离 “procedural 结构” 的因果贡献。

Figure 扩展解读:使用 DeepSeek-R1 作为 teacher 生成 20k responses 进行对照训练,整体 dip-and-recovery 模式与默认 Qwen3-32B teacher 保持一致,只有绝对得分存在小差异。这说明观察到的动态与 teacher model 无关,而是 long-CoT SFT 本身的属性。

Figure 扩展解读:在 Qwen2.5 系列(模型家族不同)上复现 dip-and-recovery 与 capability-dependent 结论。验证发现不局限于某一家 model series,增加结论鲁棒性。

Countdown 带来的”procedural generalization”意义特别重要——下图放大了这一对比:

Figure (Claim 1) 解读:同一 14B base 下 Countdown-CoT 仅覆盖一个算术 toy 任务,却能让 LCB v2、GPQA-D、MMLU-Pro 同时提升(甚至在 MATH500 上超过 Math-NoCoT-20k),证明真正迁移的是 long-CoT 中的 procedural patterns(回溯 / 验证 / 分解),而非 math 知识本身。

3.6 模型能力的作用

作者在完全固定 data + schedule 下对 Qwen3 {1.7B, 4B, 8B, 14B} 做 capability 扫描。

Figure 5 解读(上半):14B 模型展示最完整的 dip-and-recovery,8B、4B 次之,1.7B 基本没有 recovery 甚至最终 score 仍不如 base。即同一份数据+同一个配方,在能力不足的模型上无法触发泛化机制

Figure 5 解读(下半):1.7B 的 response length 在整个训练过程中一直保持在高位且不回落,说明它停留在”模仿表层 pattern”的阶段;14B 模型 length 随训练回落到接近 1 万 token,与”学会结构化推理”的阶段一致。由此作者把 response length 的回落再次映射成 capability-dependent 的现象:弱模型无法从”写长”过渡到”写对”。

3.7 Asymmetric generalization:安全性的受控对照

Figure 6 解读:(a) 在 HEx-PHI 上测 Attack Success Rate(ASR,越低越安全)。所有 3 个模型在 Math-CoT-20k SFT 后 ASR 明显升高(safety 下降),而 Math-NoCoT-20k 只引起轻微退化。由于两份数据共享 query 与最终 solution,唯一差异就是 <think>...</think> 这一 procedural pattern,因此作者得出结论:safety 退化不是来自 math 内容,而是被 CoT 结构”教坏”了。(b) case study:SFT 前 base 模型只会直接拒答;SFT 后模型先发出 “But … maybe it’s for educational purposes …”、“let’s assume that this is for a cybersecurity course…” 之类的自我合理化(self-rationalization),最终给出有害步骤。

作者还提供了所训模型产生的 top-k 有害”justification”词云:

Figure 7 解读:在 long-CoT 模型的有害回答中,educational, hypothetical, academic, scenario, research, purposes 等词最高频地用来”解除拒答”。这是一种训练出来的”软合规式 jailbreak”,与 Yong & Bach (2025) / Mao et al. (2025) 的发现一致。

Code-to-paper mapping table

Paper ConceptSource FileKey Class / Functiongithub_ref
默认训练脚本(Sec. 2.1)training_scripts/Qwen3-14B_Math-CoT-20k_lr5e-5_ep8_bs256.shtorchrun ... verl.trainer.fsdp_sft_trainer_ours ...Nebularaid2000/rethink_sft_generalization@86e4d3eb
各 data / model / LR / epoch 变体training_scripts/*.sh(共 34 个脚本)Nebularaid2000/rethink_sft_generalization@86e4d3eb
自定义 SFT trainer(§3.2)verl/trainer/fsdp_sft_trainer_ours.pyFSDPSFTTrainer__init___build_model_optimizertraining_step_compute_loss_and_backwardNebularaid2000/rethink_sft_generalization@86e4d3eb
token 级 log-prob + vanilla NLLverl/trainer/fsdp_sft_trainer_ours.pycompute_loss(..., imp_sampling_mode="vanilla")Nebularaid2000/rethink_sft_generalization@86e4d3eb
重要性采样 / PPO 风格变体(non-default)verl/trainer/fsdp_sft_trainer_ours.pycompute_loss(..., imp_sampling_mode ∈ {teacher-weighted, dft, adv-only, ppo-ori, ppo, clip, cispo, gspo-ori, gspo})Nebularaid2000/rethink_sft_generalization@86e4d3eb
Loss aggregationverl/trainer/ppo/core_algos.pyagg_losstoken-mean / seq-mean-token-sum / seq-mean-token-mean / seq-mean-token-sum-normNebularaid2000/rethink_sft_generalization@86e4d3eb
SFT 数据 pipeline(prompt+response+mask+advantage+ref_log_prob)verl/utils/dataset/sft_dataset.pyOurSFTDataset(默认数据类;SFTDataset 为保留的旧版)Nebularaid2000/rethink_sft_generalization@86e4d3eb
Multi-turn SFT(非默认)verl/utils/dataset/multiturn_sft_dataset.pyMultiTurnSFTDataset
math-verify 验证器(data filter)verl/utils/reward_score/math_verify.pycompute_score(对 Math-CoT-20k 生成后做答案校验)
FSDP checkpointverl/utils/checkpoint/fsdp_checkpoint_manager.pyFSDPCheckpointManager
默认 SFT 配置(含 warmup_steps_ratio=0.1verl/trainer/config/sft_trainer.yaml
LoRA / LR scheduler / optimizerverl/utils/torch_functional.pyverl/utils/fsdp_utils.pyget_cosine_schedule_with_warmup
Evaluation 套件(MATH500, AIME24, LCB v2, GPQA-D, MMLU-Pro, IFEval, AlpacaEval, HaluEval, TruthfulQA)evaluation/evalchemy/eval/chat_benchmarks/*各 task 的 eval_instruct.py
Safety 评测(HEx-PHI)evaluation/harmbench/*run_pipeline.pyscripts/step*.sh

默认训练命令摘要(摘自 training_scripts/Qwen3-14B_Math-CoT-20k_lr5e-5_ep8_bs256.sh):

torchrun --nnodes=$NODE_COUNT --nproc_per_node=$PROC_PER_NODE \
    --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    -m verl.trainer.fsdp_sft_trainer_ours \
    data.train_files=$TRAIN_DATA data.val_files=$VAL_DATA \
    data.prompt_key=message data.response_key=response \
    data.logprob_key=logprob data.advantage_key=advantage \
    data.train_batch_size=256 data.micro_batch_size_per_gpu=4 \
    data.max_length=20000 data.truncation=right data.shuffle_train=False \
    data.ref_log_prob.enable=True data.ref_log_prob.only_return_mean_logprob=False \
    model.partial_pretrain=Qwen/Qwen3-14B-Base \
    model.fsdp_config.model_dtype=bf16 model.trust_remote_code=True \
    optim.lr=5e-5 optim.betas="[0.9,0.999]" optim.lr_scheduler=cosine \
    trainer.total_epochs=8 trainer.save_freq=10 \
    trainer.importance_sampling_mode=vanilla \
    trainer.clip_ratio_low=1.0 trainer.clip_ratio_high=1.0 \
    trainer.loss_agg_mode=token-mean \
    trainer.checkpoint.save_contents='["model"]'

配置默认值 warmup_steps_ratio=0.1verl/trainer/config/sft_trainer.yaml 提供,即前 10% steps 做 linear warmup,其后按 cosine 衰减。

4. Experimental Setup (实验设置)

Base models(均使用 pretrained 而非 instruct 版本,避免 alignment 混淆):

  • Qwen3-Base:1.7B / 4B / 8B / 14B;
  • InternLM2.5-20B-Base;
  • Qwen2.5-Base:1.5B / 3B / 7B / 14B(验证跨家族鲁棒性)。

数据集(全部 20k 样本,query 相同便于对照):

NameQuery 来源Response 来源备注
Math-CoT-20k(默认)OpenR1-Math-220k 子集Qwen3-32B + think,math-verify 过滤<think>
Math-NoCoT-20k同上删除 <think>...</think>保留 step-by-step 总结
NuminaMath-20k同上人工解质量参差,多数未含 CoT
Countdown-CoT-20kCountdown 游戏Qwen3-32B + think4 元运算抵达目标
DeepSeek-R1-20k同 Math-CoT 的 queryDeepSeek-R1 生成用于 teacher 消融

Training protocol(默认,除非说明):AdamW,LR 5e-5,BS 256,cosine LR schedule,8 epochs,max len 20k tokens,FSDP,bf16。

Evaluation suite

  • ID reasoning:MATH500(avg@3)、AIME24(avg@10);
  • OOD reasoning:LCB v2(avg@3)、GPQA-Diamond(avg@3)、MMLU-Pro(pass@1);
  • General capabilities:IFEval(strict instruction-level pass@1)、AlpacaEval 2.0(Llama-3.1-8B-Instruct-RM-RB2 reward)、HaluEval(pass@1)、TruthfulQA(helpful via official judge);
  • Safety:HEx-PHI(ASR w/ GPT-4.1 judge;score≥5 视为成功攻击)。

解码temperature=0.6max_tokens=32,768,zero-shot。

5. Experimental Results (实验结果)

5.1 Replication:短 epoch 会错误地给出”SFT 不泛化”结论

Figure 2 解读:按 Huan et al. (2025) 的短 epoch 协议,Math-CoT-20k 训 1 epoch 后在 Qwen3-14B-Base 上的效果:ID reasoning 强提升(MATH500 +12.7%、AIME24 +29.7%),OOD reasoning 仅有限提升(LCB v2 +2.9%、GPQA-D +4.7%、MMLU-Pro +7.5%),general capability 出现明显退化(IFEval −9.8%、AlpacaEval RM −0.11、TruthfulQA −15.1%,仅 HaluEval +21.3%)。这一图恰好复现了文献中”SFT 不 OOD 泛化”的经典画面——但看起来像”SFT 无效”的原因其实是 under-optimization,而非 SFT 目标的固有缺陷。

5.2 完整 8 epoch:dip-and-recovery & 最终 OOD 泛化

见 Figure 3(上文已贴)。8 epoch 后 Qwen3-14B 达到:

BenchmarkBase1-epoch SFT8-epoch SFT (Default)
MATH50077.8%+12.7% → 90.5%95.1%
AIME2414.7%44.4%66.0%
LCB v237.5%40.4%55.1%
GPQA-D44.1%48.8%63.3%
MMLU-Pro61.8%69.3% (+7.5pp)74.4%
IFEval64.2%54.4%68.9%
AlpacaEval RM0.530.421.42
HaluEval54.7%72.8%
TruthfulQA (helpful)94.4%95.6%

即不仅 ID 提升,OOD 与大多数 general capability 也都回归并超过了 base model。

5.3 数据对比:Tab. 2

见 §3.5 对照设计。核心数字(Qwen3-14B 为例):

数据配置MATH500AIME24LCB v2GPQA-DMMLU-ProIFEvalAlpacaEval (RM)HaluEvalTruthfulQA
Base77.8%14.7%37.5%44.1%61.8%64.2%0.5354.7%94.4%
Math-CoT-20k95.1%66.0%55.1%63.3%74.4%68.9%1.4272.8%95.6%
Math-NoCoT-20k82.4%17.0%40.3%48.3%69.1%71.7%2.1170.9%100%
NuminaMath-20k74.8%14.0%20.4%38.4%59.0%52.8%−0.4562.7%88.6%
Countdown-CoT-20k91.5%41.7%43.8%53.0%65.4%61.3%1.3672.3%92.4%

关键观察:

  • Math-CoT 全面胜出 reasoning:MATH500/AIME24/LCB v2/GPQA-D/MMLU-Pro/HaluEval 均最高;
  • Math-NoCoT 在 IF / AlpacaEval / TruthfulQA 更强:说明 instruction following 类任务更需要”短平快”的答案格式,long-CoT 反而有害;
  • NuminaMath(低质量)broadly 下降:让人误以为 “SFT 不泛化”;
  • Countdown-CoT 跨域迁移明显:只学了一个 toy 算术 + CoT,却能在 MATH500 / LCB v2 / GPQA-D 上显著超过 base,说明迁移的是 procedural 结构。

5.4 Sec. 5 模型能力扫描(Qwen3-1.7B/4B/8B/14B)

见 Figure 5。下方是额外曲线:

Figure 扩展解读:上图是 4 个模型大小在各 benchmark 上的 score 轨迹,下图是对应 response length 轨迹。(1)1.7B 在 GPQA-D / LCB v2 等 OOD 上永远未能 recovery,MMLU-Pro 也持续下降;(2)4B 有微弱 recovery 但幅度小;(3)14B 是最佳,且 response length 回落到最低位。这是 “generalization requires sufficient model capability” 的直接证据。

Qwen2.5 家族也重复此结论:

Figure 扩展解读:Qwen2.5 1.5B/3B/7B/14B 的 score/length 曲线呈相同梯度。由于 Qwen2.5 系列绝对分值较低(与 Qwen3 不同家族),这进一步排除”只是某个模型的巧合”的可能。

5.5 安全性对照

见 Figure 6 (HEx-PHI ASR) 与 Figure 7 (word cloud)。总结:

  • 3 个模型在 Math-CoT-20k SFT 后 ASR 从约 515% 上升到 3060%(多步);
  • Math-NoCoT-20k 的安全下降只有约 5 pp;
  • “safety drop 来自 procedural patterns 而非 math 内容”是论文做出的最重要一个因果陈述

5.6 其他 benchmark 总表(8-epoch)

Figure 8 解读:覆盖 MATH500 / AIME24 / LCB v2 / GPQA-D / MMLU-Pro / IFEval / AlpacaEval / HaluEval 的 8-epoch 训练轨迹;所有 9 个 benchmark 的 dip-and-recovery 模式被同时呈现。图中还附带 Qwen2.5 的 cross-family 对照,作为 App. C.2 的补充证据。

5.7 Discussion & 局限

  • 条件性而非否定性:论文并非声称 “RL 不比 SFT 好”,而是强调 SFT 与 RL 的对比应该在 comparable 优化/数据/能力条件下进行。
  • SFT 仍可作为 reasoning 能力的强基线:在合适条件下,SFT 可以得到 +10~+15 pp 的 OOD 提升,足以挑战已有”SFT 只能记忆”的叙事。
  • 安全退化是真实代价:即便 reasoning 泛化,safety 也会下降,因此 post-training 流程中必须显式考虑 safety-preserving 后处理(RLHF / 过滤 / refusal 增强)。
  • 局限
    • 主要限于 math-only training data,尚未系统扫描 code / science domain 的 reasoning SFT;
    • Base models 仅限 Qwen/InternLM 两个家族;
    • 评测主要以 pass@k / RM / judge-based 为主,未涉及人类偏好实验;
    • 对 “response length 是优化阶段的诊断指标” 仍是相关性证据,缺乏严格的因果实验。