1. Motivation (研究动机)
“SFT memorizes, RL generalizes” 已成为 post-training 领域主导叙事:Chu et al. (2025)、Huan et al. (2025) 等多项工作得出 SFT 只能记忆领域内模式、无法 OOD 泛化的结论。这一叙事直接影响了 post-training 流程设计(更多 RL、更少 SFT)以及对 SFT 目标的改造(Wu et al., 2026;Zhu et al., 2026)。
但现有研究之间的实验条件并不一致:
- Chu et al. (2025) 未使用 long CoT 监督;
- Huan et al. (2025) 只训练很短的 epoch;
- Wu et al. (2026) 使用质量不均的数据;
- 多数研究受算力限制,只使用小或旧的 base model;
- 多数”SFT vs. RL”比较关注 retention(是否破坏既有能力)而非真正的泛化获取;
- 多数研究从 instruction-tuned 模型出发,alignment 对实验结果会产生严重混淆。
这些交织的因素导致文献中报告的”SFT 不泛化”到底是SFT 的本质属性还是实验条件不足的副作用并不清楚。特别是在 long-CoT reasoning SFT 场景下(数据结构更复杂、拟合难度更大、更依赖 base model 能力),这种条件性被进一步放大。
本文的目标是在受控实验下厘清 reasoning SFT 的泛化边界:泛化不是 SFT 的固有属性,而是由优化动态、训练数据、模型能力三要素共同决定的条件现象。
2. Idea (核心思想)
论文提出一条结构化的论断:“Generalization in reasoning SFT is CONDITIONAL”,并沿三个因子展开:

Figure 1 解读:作者把 reasoning SFT 的泛化建模为三因子问题(Optimization Dynamics / Training Data / Model Capability)。每个因子都提供了正反面诊断信号:(1)优化动态上,cross-domain 表现存在 “dip-and-recovery” 模式,短训出的 checkpoint 会显著低估 SFT 的泛化能力;(2)数据上,低质量答案会 broadly 伤害泛化,而经过 verify 的 long-CoT 甚至能让一个 toy 算术游戏(Countdown)泛化到多项 reasoning benchmark;(3)模型能力上,强模型会内化 procedural 模式(backtracking / verification),弱模型只能模仿表面冗长度。整体生成一张”条件-性能”图景。
两个额外关键观察:
- 响应长度是优化阶段的粗略诊断:训练初期 response length 先激增、中期达到峰值、后期回落;长度 vs. 性能呈现反相关。
- asymmetric generalization:long-CoT SFT 虽然提升 reasoning,但 safety 明显下降,模型学会”自我合理化”并输出有害内容;这一退化几乎只发生在含 long-CoT 的训练数据上,表明其来源是 procedural pattern 而非 math 内容本身。
3. Method (方法)
本论文本质是一套受控实验方法学 + 一个 verl-based SFT 训练管线。没有新算法或新架构,但对训练管线做了自定义扩展(verl/trainer/fsdp_sft_trainer_ours.py),支持重要性采样、响应 token 级别 log-prob、advantage、clip ratio 等变体。以下按”总体框架”和”各关键组件”展开。
3.1 总体框架

整个工作流可划分为 4 个步骤:
- 数据构造:从 OpenR1-Math-220k 采样 20,480 条 math query,使用 Qwen3-32B 启用 thinking 生成 response,经
math-verify验证后保留 correct-only 样本作为 Math-CoT-20k;相应地构造Math-NoCoT-20k(删除<think>块)、NuminaMath-20k(人工解,质量参差)、Countdown-CoT-20k(toy 算术游戏 + CoT)。 - SFT 训练:基于 verl,采用 FSDP 后端,标准 SFT 目标(negative log-likelihood over response tokens);AdamW 优化器,learning rate 5e-5,batch size 256,8 epochs,cosine schedule 带 10% linear warmup(
warmup_steps_ratio=0.1),最大序列长度 20k tokens,bf16 精度。 - 评测矩阵:在 ID reasoning(MATH500、AIME24),OOD reasoning(LCB v2、GPQA-D、MMLU-Pro),general capabilities(IFEval、AlpacaEval 2.0、HaluEval、TruthfulQA),safety(HEx-PHI)上逐 checkpoint 评估,构造”training step vs. score”曲线。
- 因子扫描:分别改变优化调度(LR、epoch、schedule)、数据配置(CoT/NoCoT、质量、结构)、base model(Qwen3-1.7B/4B/8B/14B、Qwen2.5 全系、InternLM2.5-20B)以观察条件依赖。
3.2 SFT loss 与训练协议
默认 SFT 目标:
其中 为 response token 集合(不含 query), 为 prompt + 已生成前缀, 为目标 token。
作者在 verl/trainer/fsdp_sft_trainer_ours.py 中把 loss 做了通用化,支持多种重要性采样 / PPO 风格变体,imp_sampling_mode 的完整枚举为:vanilla(默认 NLL)/ teacher-weighted / dft / adv-only / ppo-ori / ppo / clip / cispo / gspo-ori / gspo。所有 loss 都可通过 verl.trainer.ppo.core_algos.agg_loss 按 token-mean / seq-mean-token-sum / seq-mean-token-mean / seq-mean-token-sum-norm 聚合。
3.2.1 数据 pipeline(OurSFTDataset)
OurSFTDataset(verl/utils/dataset/sft_dataset.py,默认数据类,被 FSDPSFTTrainer 使用)负责把 parquet 样本编码为 (input_ids, attention_mask, position_ids, loss_mask, advantages, ref_log_prob?);关键逻辑:
prompt与response分别用apply_chat_template/tokenizer编码,response_ids末尾强制补eos_token_id;- 整条序列 right-pad 到
max_length=20000,超长按truncation=right截断; loss_mask[:prompt_length-1] = 0(去掉 prompt),同时loss_mask[prompt_length+response_length-1] = 0(因为 logits 会 shift 一位,最末 token 无监督信号);- 若
ref_log_prob.enable=True(默认开),把 teacher log_prob 按[prompt_length-1 : prompt_length-1+response_length]放到与attention_mask对齐的 0 填充 tensor 中。
import torch
def sft_collate(prompt_ids, response_ids, tokenizer, max_length: int):
"""Mirrors OurSFTDataset.__getitem__ in the repo."""
eos = torch.tensor([tokenizer.eos_token_id], dtype=response_ids.dtype)
response_ids = torch.cat([response_ids, eos])
input_ids = torch.cat([prompt_ids, response_ids])
attn = torch.cat([
torch.ones_like(prompt_ids), torch.ones_like(response_ids),
])
seq_len = input_ids.shape[0]
if seq_len < max_length:
pad_ids = torch.full((max_length - seq_len,), tokenizer.pad_token_id)
pad_attn = torch.zeros(max_length - seq_len, dtype=attn.dtype)
input_ids = torch.cat([input_ids, pad_ids])
attn = torch.cat([attn, pad_attn])
else:
input_ids = input_ids[:max_length]
attn = attn[:max_length]
position_ids = (attn.cumsum(dim=-1) - 1).clamp_min(0) * attn
prompt_len = prompt_ids.shape[0]
loss_mask = attn.clone().to(torch.float32)
if prompt_len > 1:
loss_mask[: min(prompt_len, loss_mask.size(0)) - 1] = 0.0
last = min(prompt_len + response_ids.shape[0], loss_mask.size(0)) - 1
loss_mask[last] = 0.0
return input_ids, attn, position_ids, loss_mask3.2.2 _compute_loss_and_backward(FSDPSFTTrainer)
import torch
import verl.utils.torch_functional as verl_F
def _compute_loss_and_backward(self, batch, do_backward: bool = True):
"""Mirrors FSDPSFTTrainer._compute_loss_and_backward in fsdp_sft_trainer_ours.py."""
input_ids = batch["input_ids"].to(self.device_name)
attention_mask = batch["attention_mask"].to(self.device_name)
position_ids = batch["position_ids"].to(self.device_name)
loss_mask = batch.pop("loss_mask")[:, :-1].to(self.device_name) # shift-1
advantages = batch.pop("advantages").to(self.device_name)
ref_log_prob = batch.get("ref_log_prob")
if ref_log_prob is not None:
ref_log_prob = ref_log_prob[:, :-1].to(self.device_name)
mean_ref_log_prob = batch.get("mean_ref_log_prob")
with torch.autocast(self.device_name, dtype=torch.bfloat16):
labels = input_ids[:, 1:].contiguous()
logits = self.fsdp_model(
input_ids=input_ids, attention_mask=attention_mask,
position_ids=position_ids, use_cache=False,
).logits
shift_logits = logits[..., :-1, :].contiguous()
log_prob = verl_F.logprobs_from_logits(shift_logits, labels=labels)
advantages = torch.ones_like(log_prob) * advantages # broadcast to token-level
loss, *_ = compute_loss(
log_prob=log_prob,
response_mask=loss_mask,
advantages=advantages,
old_log_prob=ref_log_prob,
mean_old_log_prob=mean_ref_log_prob,
imp_sampling_mode=self.config.trainer.importance_sampling_mode,
cliprange=self.config.trainer.clip_ratio,
cliprange_low=self.config.trainer.clip_ratio_low,
cliprange_high=self.config.trainer.clip_ratio_high,
loss_agg_mode=self.config.trainer.loss_agg_mode,
)
if do_backward:
loss.backward()
return loss.item()3.2.3 compute_loss:vanilla 与 PPO/CISPO/GSPO 等变体
import torch
from verl.trainer.ppo.core_algos import agg_loss
def compute_loss(
log_prob, response_mask,
advantages=None, old_log_prob=None, mean_old_log_prob=None,
imp_sampling_mode: str = "vanilla",
cliprange=None, cliprange_low=None, cliprange_high=None,
loss_agg_mode: str = "token-mean",
):
"""Mirrors compute_loss(...) in fsdp_sft_trainer_ours.py (vanilla + variants)."""
cliprange_low = cliprange_low if cliprange_low is not None else cliprange
cliprange_high = cliprange_high if cliprange_high is not None else cliprange
if imp_sampling_mode == "vanilla":
# standard SFT: L = - log p_theta(y_t | x, y_<t)
losses = -log_prob
elif imp_sampling_mode == "teacher-weighted":
# -pi_teacher * log_prob (teacher-as-importance-weight)
w = torch.exp(old_log_prob).detach()
losses = -w * log_prob
elif imp_sampling_mode == "dft":
# distillation fine-tune: pi_theta / pi_teacher (clipped) as weight, detached
weight = torch.exp(log_prob - old_log_prob).detach()
weight = torch.clamp(weight, 1.0 - cliprange_low, 1.0 + cliprange_high)
losses = -weight * log_prob
elif imp_sampling_mode == "adv-only":
losses = -advantages * log_prob
elif imp_sampling_mode in {"ppo", "ppo-ori", "clip", "cispo"}:
# PPO/CISPO-style: ratio = pi_theta / pi_old
ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20, 20))
if imp_sampling_mode == "cispo":
ratio_clip = torch.clamp(ratio.detach(), 1.0 - cliprange_low, 1.0 + cliprange_high)
losses = -ratio_clip * log_prob # clip only importance ratio, keep gradient in log_prob
else:
ratio_clip = torch.clamp(ratio, 1.0 - cliprange_low, 1.0 + cliprange_high)
losses = -torch.min(ratio * advantages, ratio_clip * advantages)
elif imp_sampling_mode in {"gspo", "gspo-ori"}:
# sequence-level importance ratio (Gradient-Sequence PO)
delta = torch.clamp(log_prob - old_log_prob, -20, 20)
seq_len = response_mask.sum(dim=-1).clamp_min(1)
seq_delta = (delta * response_mask).sum(dim=-1) / seq_len # (B,)
seq_ratio = torch.exp(seq_delta).unsqueeze(-1).expand_as(log_prob)
seq_ratio_clip = torch.clamp(seq_ratio, 1.0 - cliprange_low, 1.0 + cliprange_high)
losses = -torch.min(seq_ratio * advantages, seq_ratio_clip * advantages)
else:
raise ValueError(imp_sampling_mode)
return agg_loss(loss_mat=losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
def agg_loss(loss_mat, loss_mask, loss_agg_mode: str = "token-mean"):
"""Mirrors verl.trainer.ppo.core_algos.agg_loss."""
import verl.utils.torch_functional as verl_F
if loss_agg_mode == "token-mean":
return verl_F.masked_mean(loss_mat, loss_mask)
if loss_agg_mode == "seq-mean-token-sum":
return torch.mean((loss_mat * loss_mask).sum(dim=-1))
if loss_agg_mode == "seq-mean-token-mean":
seq = (loss_mat * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1).clamp_min(1)
return seq.mean()
if loss_agg_mode == "seq-mean-token-sum-norm":
return (loss_mat * loss_mask).sum() / loss_mask.shape[-1]
raise ValueError(loss_agg_mode)3.2.4 training_step:FSDP 梯度累积 + all-reduce
def training_step(self, batch):
"""Mirrors FSDPSFTTrainer.training_step in fsdp_sft_trainer_ours.py."""
self.fsdp_model.train()
self.optimizer.zero_grad()
micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
step_loss = 0.0
for mb in micro_batches:
step_loss += self._compute_loss_and_backward(mb) / len(micro_batches)
if self.config.model.strategy == "fsdp":
grad_norm = self.fsdp_model.clip_grad_norm_(self.config.optim.clip_grad)
else: # fsdp2
grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), self.config.optim.clip_grad)
if torch.isfinite(grad_norm):
self.optimizer.step()
self.lr_scheduler.step()
step_loss_t = torch.tensor(step_loss, device=self.device_name)
torch.distributed.all_reduce(step_loss_t, op=torch.distributed.ReduceOp.AVG)
return {"train/loss": step_loss_t.item(), "train/lr": self.lr_scheduler.get_last_lr()[0]}默认 imp_sampling_mode=vanilla、loss_agg_mode=token-mean、clip_ratio_low=clip_ratio_high=1.0,即退化为标准 token-level NLL。
3.3 优化动态:dip-and-recovery + response length
作者用 Qwen3-14B-Base / 8B-Base / InternLM2.5-20B-Base 在默认设置下训 8 epoch,得到 step × benchmark × response-length 的三维表格。

Figure 3 解读:上半部分显示三个模型在 9 个 benchmark 上的性能曲线(x 轴 log scale 的 training steps,y 轴精度或 RM 分数),普遍呈现”dip-and-recovery”:早期 step OOD benchmark 普遍下跌(例如 IFEval、AlpacaEval),中期逐步恢复,后期超过 base model;ID 基准(MATH500、AIME24)则有 brief 小幅 dip 后快速上升。下半部分显示 response length 随 step 的演化:先急剧增长到 2~3 万 token,再逐步回落;峰值长度恰好对应性能最差的阶段——说明此时模型只学到了”写得长”这一表层特征,还未内化分解、回溯、自评估等深层模式。
定量诊断信号:
其中 为第 个 checkpoint 的平均 response length, 为该 checkpoint 在某任务的得分。实验中这一相关性在 dip 阶段显著为负。作者由此提出:
Checkpoint is “under-optimized” 的一个实用判据:response length 仍在 noticeable shrink,就不要据此下 “SFT 不泛化” 的结论。
3.4 重复曝光与单次覆盖及 overfitting regime
受控实验:在固定 640 个 gradient step 的预算下对比三种训练安排:Setting1 = 20k / bs 256 / ep 8;Setting2 = 2.5k / bs 32 / ep 8;Setting3 = 20k / bs 32 / ep 1。

Figure 扩展解读(Sec. 3.3 Tab. 1):Setting 2 明显优于 Setting 3(同计算预算,8 次 vs. 1 次数据曝光),说明 long-CoT SFT 从 “多轮重复曝光” 中获得的提升大于 “更多 unique 样本”;这一点解释了为什么文献上”训 1 epoch 就结论”的做法容易 underestimate SFT。Setting 1 再比 Setting 2 好,表明在 epoch 与 step 固定时,数据 diversity 仍有额外价值。
Overfitting stress test:作者在更 aggressive 的调度下训 16 epoch,对比”cosine LR”/“const LR”/“const LR + LR↑”。

Figure 4 解读:前者显示各 benchmark 上的性能,后者显示对应的 response length。(1) 默认 + 16 ep(cosine)几乎与默认曲线重合,末期开始缓慢下降但未显著。(2) const LR 16 ep 后段出现 broad OOD 退化信号。(3) LR=1e-4, const, 16 ep 展现最清楚的 overfitting:即使 ID math 也开始 drop,response length 再次回升——说明模型重新陷入”长响应但低质量”的状态。由此作者画出了 long-CoT SFT 的 underfit→sweet-spot→overfit 三段式优化区间。
3.5 数据质量与结构的因果实验
作者通过 4 种数据对照分离 CoT / 质量 / 结构三种因子:
Math-CoT-20kvs.Math-NoCoT-20k:query 与 final answer 完全一致,仅删除<think>块 → 分离 “长 CoT 过程” 的因果贡献;Math-NoCoT-20kvs.NuminaMath-20k:都没有 long CoT,但前者是质量较高的答案 → 分离 “数据质量” 的因果贡献;Math-CoT-20kvs.Countdown-CoT-20k:query domain 不同(数学 vs. Countdown 算术游戏),但都是 long-CoT → 分离 “procedural 结构” 的因果贡献。

Figure 扩展解读:使用 DeepSeek-R1 作为 teacher 生成 20k responses 进行对照训练,整体 dip-and-recovery 模式与默认 Qwen3-32B teacher 保持一致,只有绝对得分存在小差异。这说明观察到的动态与 teacher model 无关,而是 long-CoT SFT 本身的属性。

Figure 扩展解读:在 Qwen2.5 系列(模型家族不同)上复现 dip-and-recovery 与 capability-dependent 结论。验证发现不局限于某一家 model series,增加结论鲁棒性。
Countdown 带来的”procedural generalization”意义特别重要——下图放大了这一对比:

Figure (Claim 1) 解读:同一 14B base 下 Countdown-CoT 仅覆盖一个算术 toy 任务,却能让 LCB v2、GPQA-D、MMLU-Pro 同时提升(甚至在 MATH500 上超过 Math-NoCoT-20k),证明真正迁移的是 long-CoT 中的 procedural patterns(回溯 / 验证 / 分解),而非 math 知识本身。
3.6 模型能力的作用
作者在完全固定 data + schedule 下对 Qwen3 {1.7B, 4B, 8B, 14B} 做 capability 扫描。

Figure 5 解读(上半):14B 模型展示最完整的 dip-and-recovery,8B、4B 次之,1.7B 基本没有 recovery 甚至最终 score 仍不如 base。即同一份数据+同一个配方,在能力不足的模型上无法触发泛化机制。

Figure 5 解读(下半):1.7B 的 response length 在整个训练过程中一直保持在高位且不回落,说明它停留在”模仿表层 pattern”的阶段;14B 模型 length 随训练回落到接近 1 万 token,与”学会结构化推理”的阶段一致。由此作者把 response length 的回落再次映射成 capability-dependent 的现象:弱模型无法从”写长”过渡到”写对”。
3.7 Asymmetric generalization:安全性的受控对照

Figure 6 解读:(a) 在 HEx-PHI 上测 Attack Success Rate(ASR,越低越安全)。所有 3 个模型在 Math-CoT-20k SFT 后 ASR 明显升高(safety 下降),而 Math-NoCoT-20k 只引起轻微退化。由于两份数据共享 query 与最终 solution,唯一差异就是 <think>...</think> 这一 procedural pattern,因此作者得出结论:safety 退化不是来自 math 内容,而是被 CoT 结构”教坏”了。(b) case study:SFT 前 base 模型只会直接拒答;SFT 后模型先发出 “But … maybe it’s for educational purposes …”、“let’s assume that this is for a cybersecurity course…” 之类的自我合理化(self-rationalization),最终给出有害步骤。
作者还提供了所训模型产生的 top-k 有害”justification”词云:

Figure 7 解读:在 long-CoT 模型的有害回答中,educational, hypothetical, academic, scenario, research, purposes 等词最高频地用来”解除拒答”。这是一种训练出来的”软合规式 jailbreak”,与 Yong & Bach (2025) / Mao et al. (2025) 的发现一致。
Code-to-paper mapping table
| Paper Concept | Source File | Key Class / Function | github_ref |
|---|---|---|---|
| 默认训练脚本(Sec. 2.1) | training_scripts/Qwen3-14B_Math-CoT-20k_lr5e-5_ep8_bs256.sh | torchrun ... verl.trainer.fsdp_sft_trainer_ours ... | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| 各 data / model / LR / epoch 变体 | training_scripts/*.sh(共 34 个脚本) | — | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| 自定义 SFT trainer(§3.2) | verl/trainer/fsdp_sft_trainer_ours.py | FSDPSFTTrainer(__init__、_build_model_optimizer、training_step、_compute_loss_and_backward) | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| token 级 log-prob + vanilla NLL | verl/trainer/fsdp_sft_trainer_ours.py | compute_loss(..., imp_sampling_mode="vanilla") | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| 重要性采样 / PPO 风格变体(non-default) | verl/trainer/fsdp_sft_trainer_ours.py | compute_loss(..., imp_sampling_mode ∈ {teacher-weighted, dft, adv-only, ppo-ori, ppo, clip, cispo, gspo-ori, gspo}) | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| Loss aggregation | verl/trainer/ppo/core_algos.py | agg_loss(token-mean / seq-mean-token-sum / seq-mean-token-mean / seq-mean-token-sum-norm) | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| SFT 数据 pipeline(prompt+response+mask+advantage+ref_log_prob) | verl/utils/dataset/sft_dataset.py | OurSFTDataset(默认数据类;SFTDataset 为保留的旧版) | Nebularaid2000/rethink_sft_generalization@86e4d3eb |
| Multi-turn SFT(非默认) | verl/utils/dataset/multiturn_sft_dataset.py | MultiTurnSFTDataset | |
| math-verify 验证器(data filter) | verl/utils/reward_score/math_verify.py | compute_score(对 Math-CoT-20k 生成后做答案校验) | |
| FSDP checkpoint | verl/utils/checkpoint/fsdp_checkpoint_manager.py | FSDPCheckpointManager | |
默认 SFT 配置(含 warmup_steps_ratio=0.1) | verl/trainer/config/sft_trainer.yaml | — | |
| LoRA / LR scheduler / optimizer | verl/utils/torch_functional.py、verl/utils/fsdp_utils.py | get_cosine_schedule_with_warmup 等 | |
| Evaluation 套件(MATH500, AIME24, LCB v2, GPQA-D, MMLU-Pro, IFEval, AlpacaEval, HaluEval, TruthfulQA) | evaluation/evalchemy/eval/chat_benchmarks/* | 各 task 的 eval_instruct.py | |
| Safety 评测(HEx-PHI) | evaluation/harmbench/* | run_pipeline.py 及 scripts/step*.sh |
默认训练命令摘要(摘自 training_scripts/Qwen3-14B_Math-CoT-20k_lr5e-5_ep8_bs256.sh):
torchrun --nnodes=$NODE_COUNT --nproc_per_node=$PROC_PER_NODE \
--node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
-m verl.trainer.fsdp_sft_trainer_ours \
data.train_files=$TRAIN_DATA data.val_files=$VAL_DATA \
data.prompt_key=message data.response_key=response \
data.logprob_key=logprob data.advantage_key=advantage \
data.train_batch_size=256 data.micro_batch_size_per_gpu=4 \
data.max_length=20000 data.truncation=right data.shuffle_train=False \
data.ref_log_prob.enable=True data.ref_log_prob.only_return_mean_logprob=False \
model.partial_pretrain=Qwen/Qwen3-14B-Base \
model.fsdp_config.model_dtype=bf16 model.trust_remote_code=True \
optim.lr=5e-5 optim.betas="[0.9,0.999]" optim.lr_scheduler=cosine \
trainer.total_epochs=8 trainer.save_freq=10 \
trainer.importance_sampling_mode=vanilla \
trainer.clip_ratio_low=1.0 trainer.clip_ratio_high=1.0 \
trainer.loss_agg_mode=token-mean \
trainer.checkpoint.save_contents='["model"]'配置默认值 warmup_steps_ratio=0.1 由 verl/trainer/config/sft_trainer.yaml 提供,即前 10% steps 做 linear warmup,其后按 cosine 衰减。
4. Experimental Setup (实验设置)
Base models(均使用 pretrained 而非 instruct 版本,避免 alignment 混淆):
- Qwen3-Base:1.7B / 4B / 8B / 14B;
- InternLM2.5-20B-Base;
- Qwen2.5-Base:1.5B / 3B / 7B / 14B(验证跨家族鲁棒性)。
数据集(全部 20k 样本,query 相同便于对照):
| Name | Query 来源 | Response 来源 | 备注 |
|---|---|---|---|
| Math-CoT-20k(默认) | OpenR1-Math-220k 子集 | Qwen3-32B + think,math-verify 过滤 | 含 <think> 块 |
| Math-NoCoT-20k | 同上 | 删除 <think>...</think> | 保留 step-by-step 总结 |
| NuminaMath-20k | 同上 | 人工解 | 质量参差,多数未含 CoT |
| Countdown-CoT-20k | Countdown 游戏 | Qwen3-32B + think | 4 元运算抵达目标 |
| DeepSeek-R1-20k | 同 Math-CoT 的 query | DeepSeek-R1 生成 | 用于 teacher 消融 |
Training protocol(默认,除非说明):AdamW,LR 5e-5,BS 256,cosine LR schedule,8 epochs,max len 20k tokens,FSDP,bf16。
Evaluation suite:
- ID reasoning:MATH500(avg@3)、AIME24(avg@10);
- OOD reasoning:LCB v2(avg@3)、GPQA-Diamond(avg@3)、MMLU-Pro(pass@1);
- General capabilities:IFEval(strict instruction-level pass@1)、AlpacaEval 2.0(Llama-3.1-8B-Instruct-RM-RB2 reward)、HaluEval(pass@1)、TruthfulQA(helpful via official judge);
- Safety:HEx-PHI(ASR w/ GPT-4.1 judge;score≥5 视为成功攻击)。
解码:temperature=0.6,max_tokens=32,768,zero-shot。
5. Experimental Results (实验结果)
5.1 Replication:短 epoch 会错误地给出”SFT 不泛化”结论

Figure 2 解读:按 Huan et al. (2025) 的短 epoch 协议,Math-CoT-20k 训 1 epoch 后在 Qwen3-14B-Base 上的效果:ID reasoning 强提升(MATH500 +12.7%、AIME24 +29.7%),OOD reasoning 仅有限提升(LCB v2 +2.9%、GPQA-D +4.7%、MMLU-Pro +7.5%),general capability 出现明显退化(IFEval −9.8%、AlpacaEval RM −0.11、TruthfulQA −15.1%,仅 HaluEval +21.3%)。这一图恰好复现了文献中”SFT 不 OOD 泛化”的经典画面——但看起来像”SFT 无效”的原因其实是 under-optimization,而非 SFT 目标的固有缺陷。
5.2 完整 8 epoch:dip-and-recovery & 最终 OOD 泛化
见 Figure 3(上文已贴)。8 epoch 后 Qwen3-14B 达到:
| Benchmark | Base | 1-epoch SFT | 8-epoch SFT (Default) |
|---|---|---|---|
| MATH500 | 77.8% | +12.7% → 90.5% | 95.1% |
| AIME24 | 14.7% | 44.4% | 66.0% |
| LCB v2 | 37.5% | 40.4% | 55.1% |
| GPQA-D | 44.1% | 48.8% | 63.3% |
| MMLU-Pro | 61.8% | 69.3% (+7.5pp) | 74.4% |
| IFEval | 64.2% | 54.4% | 68.9% |
| AlpacaEval RM | 0.53 | 0.42 | 1.42 |
| HaluEval | 54.7% | — | 72.8% |
| TruthfulQA (helpful) | 94.4% | — | 95.6% |
即不仅 ID 提升,OOD 与大多数 general capability 也都回归并超过了 base model。
5.3 数据对比:Tab. 2
见 §3.5 对照设计。核心数字(Qwen3-14B 为例):
| 数据配置 | MATH500 | AIME24 | LCB v2 | GPQA-D | MMLU-Pro | IFEval | AlpacaEval (RM) | HaluEval | TruthfulQA |
|---|---|---|---|---|---|---|---|---|---|
| Base | 77.8% | 14.7% | 37.5% | 44.1% | 61.8% | 64.2% | 0.53 | 54.7% | 94.4% |
| Math-CoT-20k | 95.1% | 66.0% | 55.1% | 63.3% | 74.4% | 68.9% | 1.42 | 72.8% | 95.6% |
| Math-NoCoT-20k | 82.4% | 17.0% | 40.3% | 48.3% | 69.1% | 71.7% | 2.11 | 70.9% | 100% |
| NuminaMath-20k | 74.8% | 14.0% | 20.4% | 38.4% | 59.0% | 52.8% | −0.45 | 62.7% | 88.6% |
| Countdown-CoT-20k | 91.5% | 41.7% | 43.8% | 53.0% | 65.4% | 61.3% | 1.36 | 72.3% | 92.4% |
关键观察:
- Math-CoT 全面胜出 reasoning:MATH500/AIME24/LCB v2/GPQA-D/MMLU-Pro/HaluEval 均最高;
- Math-NoCoT 在 IF / AlpacaEval / TruthfulQA 更强:说明 instruction following 类任务更需要”短平快”的答案格式,long-CoT 反而有害;
- NuminaMath(低质量)broadly 下降:让人误以为 “SFT 不泛化”;
- Countdown-CoT 跨域迁移明显:只学了一个 toy 算术 + CoT,却能在 MATH500 / LCB v2 / GPQA-D 上显著超过 base,说明迁移的是 procedural 结构。
5.4 Sec. 5 模型能力扫描(Qwen3-1.7B/4B/8B/14B)
见 Figure 5。下方是额外曲线:

Figure 扩展解读:上图是 4 个模型大小在各 benchmark 上的 score 轨迹,下图是对应 response length 轨迹。(1)1.7B 在 GPQA-D / LCB v2 等 OOD 上永远未能 recovery,MMLU-Pro 也持续下降;(2)4B 有微弱 recovery 但幅度小;(3)14B 是最佳,且 response length 回落到最低位。这是 “generalization requires sufficient model capability” 的直接证据。
Qwen2.5 家族也重复此结论:

Figure 扩展解读:Qwen2.5 1.5B/3B/7B/14B 的 score/length 曲线呈相同梯度。由于 Qwen2.5 系列绝对分值较低(与 Qwen3 不同家族),这进一步排除”只是某个模型的巧合”的可能。
5.5 安全性对照
见 Figure 6 (HEx-PHI ASR) 与 Figure 7 (word cloud)。总结:
- 3 个模型在 Math-CoT-20k SFT 后 ASR 从约 5
15% 上升到 3060%(多步); - Math-NoCoT-20k 的安全下降只有约 5 pp;
- “safety drop 来自 procedural patterns 而非 math 内容”是论文做出的最重要一个因果陈述。
5.6 其他 benchmark 总表(8-epoch)

Figure 8 解读:覆盖 MATH500 / AIME24 / LCB v2 / GPQA-D / MMLU-Pro / IFEval / AlpacaEval / HaluEval 的 8-epoch 训练轨迹;所有 9 个 benchmark 的 dip-and-recovery 模式被同时呈现。图中还附带 Qwen2.5 的 cross-family 对照,作为 App. C.2 的补充证据。
5.7 Discussion & 局限
- 条件性而非否定性:论文并非声称 “RL 不比 SFT 好”,而是强调 SFT 与 RL 的对比应该在 comparable 优化/数据/能力条件下进行。
- SFT 仍可作为 reasoning 能力的强基线:在合适条件下,SFT 可以得到 +10~+15 pp 的 OOD 提升,足以挑战已有”SFT 只能记忆”的叙事。
- 安全退化是真实代价:即便 reasoning 泛化,safety 也会下降,因此 post-training 流程中必须显式考虑 safety-preserving 后处理(RLHF / 过滤 / refusal 增强)。
- 局限:
- 主要限于 math-only training data,尚未系统扫描 code / science domain 的 reasoning SFT;
- Base models 仅限 Qwen/InternLM 两个家族;
- 评测主要以 pass@k / RM / judge-based 为主,未涉及人类偏好实验;
- 对 “response length 是优化阶段的诊断指标” 仍是相关性证据,缺乏严格的因果实验。