1. Motivation (研究动机)
- 当前 Diffusion RL post-training 已能通过 GRPO / DiffusionNFT / FlowGRPO 等方法提升 text-to-image 模型的人类偏好对齐,但它们的有效更新高度依赖 rollout group 内候选样本的相对奖励差异;扩大 rollout group size 往往能带来更强的 alignment gain。
- 瓶颈在于大模型 rollout 极贵:例如 FLUX.1-12B 这类基础模型要为每个 prompt 生成大量候选,BF16 naive scaling 会把绝大多数最终不会用于训练的候选也完整高精度采样,导致 wall-clock 与 GPU-hour 成本成为主要限制。
- 本文要解决的问题是:能否用低精度 NVFP4 先大规模便宜探索候选空间,再只把最有训练价值的高对比样本用 BF16 重生成用于优化,从而同时保留 rollout scaling 的算法收益和高精度训练的稳定性。解决后,Diffusion RL 可以把更多预算花在“找更有信息量的样本”上,而不是浪费在训练不会用到的候选图像上。
2. Idea (核心思想)
- 核心 insight:FP4 rollout 不适合直接作为训练 target,因为量化误差会污染梯度和 policy update;但它足够保留同一 prompt 内候选样本的相对 reward ranking,因此可以作为高吞吐的 proxy filter。
- Sol-RL 把 rollout 拆成两个阶段:Stage 1 用 NVFP4 compiled model 以少步数生成大候选池并排序;Stage 2 只保留 top/bottom 高对比 seed,用 BF16 full rollout 重生成高保真训练样本,再用 DiffusionNFT objective 更新 policy。
- 与 DiffusionNFT / FlowGRPO 的主要差异不在 reward objective 本身,而在 rollout 数据生产路径:它们通常直接在训练精度下采样用于优化的 group;Sol-RL 则把“探索候选”和“生成训练 target”解耦,用低精度承担搜索、高精度承担优化。
3. Method (方法)
Overall framework. Sol-RL 的整体流程是“FP4 explore, BF16 train”:每轮训练先把当前 policy 的 LoRA/权重同步到 NVFP4 inference engine,针对每个 prompt 采样 个独立 initial noise,用 个 denoising steps 做快速候选生成并用 reward model 打分;然后保留 top- 与 bottom- 的 seed,其中 ;最后用 BF16 compiled model 和完整 rollout steps 从这些 seed 重生成样本,用 DiffusionNFT loss 做 policy update。
Figure 2 解读:图中左侧是低精度 exploration branch,NVFP4 model 负责便宜地产生大候选池;中间的 reward sorting 只传递 seed / ranking 信息,而不是把 FP4 图像直接用于训练;右侧 BF16 branch 根据被选中的 top/bottom seed 重生成高保真 rollout,再进入 DiffusionNFT optimization。这个结构的关键是让量化误差只影响候选筛选,不进入最终梯度 target。
Figure 1 解读:左侧展示 FLUX.1 与 SANA 经 Sol-RL fine-tuning 后的生成样例,强调该方法不只是加速,还能保持高质量视觉输出;右侧训练曲线显示在 ImageReward 上 Sol-RL 能以更少 wall-clock 时间达到 baseline 同等 reward,最高报告 convergence speedup,并收敛到更高 reward ceiling。
3.1 Preliminaries: Diffusion RL 与 GRPO
最简单的 REINFORCE policy gradient 写作:
加入 prompt-level baseline 后为:
GRPO 对每个 prompt 生成 个候选 ,用 group 内 reward 标准化得到 advantage:
对应 PPO-style clipped surrogate objective 为:
3.2 FP4 quantization 与直接量化 rollout 的问题
FP4 把高精度 tensor 通过 group shared scale 映射到 4-bit value,论文写作:
其中 是 shared scaling factor, 是投影函数。OCP MXFP4 使用 32-element group + E8M0 scale;NVIDIA NVFP4 使用 16-element group + E4M3 scale。FP4 的吞吐优势明显,但直接把 FP4 rollout 放进 RL training pipeline 会让训练不稳定,因为优化阶段依赖的是图像/trajectory 的高保真 target,而不是只依赖相对排序。
[Large figure omitted: fig3_group.svg exceeded Cloudflare Pages single-file limit.]
Figure 3a–3c 解读:这组子图在原始 arXiv 源码中属于同一个组合图,应一起看:Figure 3a 比较不同 -in- precision setting 的 iteration time breakdown,说明 naive BF16 scaling 的主要成本集中在生成大量候选上,而 FP4 rollout 能显著降低 generation time;Figure 3b 展示 direct FP4 quantized rollout 相对 BF16 baseline 的训练性能退化和不稳定,支撑“低精度可以加速 sampling,但不能无条件替代训练 target”的负面结论;Figure 3c 显示 FP4 与 BF16 在同一 group 内 reward ranking 的分布接近对角线,说明即使绝对 reward 有扰动,候选之间的相对顺序仍高度一致,Sol-RL 正是利用这个“排序可靠、训练不可靠”的分离性质。
3.3 Two-stage rollout 的直觉
Sol-RL 的直觉类似先用低成本检索器从大库中找 hard positives / hard negatives,再用高质量生成器重建训练样本。GRPO/DiffusionNFT 需要的是 group 内有足够 reward spread 的样本来形成强 advantage signal,而不是需要所有候选都作为训练 target;因此,先用 NVFP4 找到 top/bottom seed,再用 BF16 重生成这些 seed,能把探索预算放大到 ,同时把优化实际处理的样本数控制在 。这种设计把 FP4 误差限制在“选择哪些 seed”这一离散决策中,而不是让误差连续进入 loss。
3.4 Theoretical justification: ranking preservation
附录把低精度 rollout 建模为 ODE vector field 的有界扰动。高精度 trajectory 满足:
低精度 accelerated trajectory 满足:
其中 是 FP4 rounding 与低精度 solver arithmetic 的有效扰动。若 对 是 -Lipschitz,则 Grönwall inequality 给出最终 sample deviation bound:
若 reward model 是 -Lipschitz,则 fixed seed 的 absolute reward error 也被上式诱导的 sample deviation 控制。论文进一步用 Extreme Value Theory 论证:当 group size 增大且 top/bottom 样本的 reward margin 足够大时,小扰动不容易改变极端候选集合,因此 FP4 适合作为 top/bottom seed filter。
3.5 Pseudocode: NVFP4 exploration filter
import torch
@torch.no_grad()
def fp4_exploration_filter(policy, prompts, reward_fn, *, n=96, k=24, steps=6):
fp4_policy = quantize_to_nvfp4_compiled(policy)
selected = []
for prompt in prompts:
seeds = torch.randint(0, 2**31 - 1, (n,), device="cuda")
images = []
for seed in seeds:
noise = make_initial_noise(seed, prompt)
image = fp4_policy.sample(prompt, noise=noise, num_inference_steps=steps)
images.append(image)
rewards = reward_fn(images, [prompt] * n)
order = torch.argsort(rewards)
bottom = order[: k // 2]
top = order[-k // 2 :]
keep = torch.cat([bottom, top])
selected.append({"prompt": prompt, "seeds": seeds[keep], "proxy_rewards": rewards[keep]})
return selected3.6 Pseudocode: BF16 regeneration and reward/advantage construction
import torch
@torch.no_grad()
def bf16_regenerate_rollouts(policy, selected_groups, reward_fn, *, steps=10):
batch = []
for group in selected_groups:
prompt = group["prompt"]
images, trajectories = [], []
for seed in group["seeds"]:
noise = make_initial_noise(seed, prompt)
image, traj = policy.sample_with_logprob(
prompt,
noise=noise,
num_inference_steps=steps,
dtype=torch.bfloat16,
)
images.append(image)
trajectories.append(traj)
rewards = reward_fn(images, [prompt] * len(images))
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
batch.append({"prompt": prompt, "trajectories": trajectories, "advantages": advantages})
return collate_rollout_batch(batch)3.7 Pseudocode: DiffusionNFT-style policy update
import torch
import torch.nn.functional as F
def diffusion_nft_update(policy, ref_policy, old_policy, rollout_batch, optimizer, *, beta=1e-4, adv_clip=5.0):
optimizer.zero_grad()
total_loss = 0.0
for batch in rollout_batch:
x0 = batch["x0"]
timesteps = batch["timesteps"]
cond = batch["prompt_embeds"]
advantages = torch.clamp(batch["advantages"], -adv_clip, adv_clip)
r = torch.clamp((advantages / adv_clip) / 2.0 + 0.5, 0.0, 1.0)
pred = policy(x0, timesteps, cond)
old_pred = old_policy(x0, timesteps, cond).detach()
ref_pred = ref_policy(x0, timesteps, cond).detach()
positive_loss = F.mse_loss(pred, x0, reduction="none").flatten(1).mean(dim=1)
negative_loss = F.mse_loss(2 * old_pred - pred, x0, reduction="none").flatten(1).mean(dim=1)
policy_loss = (r * positive_loss + (1.0 - r) * negative_loss).mean()
kl_loss = F.mse_loss(pred, ref_pred)
total_loss = total_loss + policy_loss + beta * kl_loss
total_loss.backward()
optimizer.step()
return total_loss.detach()3.8 Pseudocode: LoRA synchronization back to compiled inference models
@torch.no_grad()
def sync_after_update(trainable_lora, inference_models, *, adapter_name="old"):
old_lora = ema_or_linear_ramp_update(trainable_lora)
for model in inference_models.values():
base = unwrap_compiled(model)
sync_lora_to_inference(old_lora, base, adapter_name=adapter_name)
if model.requires_nvfp4:
re_quantize_in_place(base)Code reference:
main@1bfb9352(2026-04-14) — pseudocode and mapping based on this commit
| Paper Concept | Source File | Key Class/Function |
|---|---|---|
| Sol-RL run recipes and launch commands | docs/sol_rl.md | training commands for SANA / FLUX.1 / SD3.5-L |
| Shared RL config defaults | configs/sol_rl/base.py | get_config() |
| SANA experiment variants | configs/sol_rl/sana.py | _build_reward(), sana_sol_rl_*(), sana_naive_scaling_*(), sana_naive_quant_*() |
| FLUX.1 experiment variants | configs/sol_rl/flux1.py | flux1_sol_rl_*(), flux1_compile_*(), flux1_naive_quant_*() |
| SD3.5 experiment variants | configs/sol_rl/sd3.py | sd3_sol_rl_*(), sd3_compile_*(), sd3_naive_quant_*() |
| NVFP4 support and model sync | train_scripts/sol_rl/train_utils.py | NVFP4_RECIPE, ensure_transformer_engine_available(), wrap_forward_with_fp8(), replace_linear_with_te(), sync_lora_to_inference() |
| SANA rollout/training loop | train_scripts/sol_rl/train_sana.py | _rollout_for_one_prompt(), stat_tracker.update(), policy loss block, sync_lora_to_inference() |
| FLUX.1 rollout/training loop | train_scripts/sol_rl/train_flux1.py | FLUX-specific rollout/training loop |
| SD3.5 rollout/training loop | train_scripts/sol_rl/train_sd3.py | SD3-specific rollout/training loop |
| Reward models | diffusion/post_training/rewards.py | HPSv2Scorer, ClipScorer, PickScoreScorer, ImageRewardScorer, multi_score() |
4. Experimental Setup (实验设置)
- Models / hardware:实验覆盖 SANA-1.5 1600M、FLUX.1-dev、Stable Diffusion 3.5-Large;全部实验使用 8 张 NVIDIA B200 GPU;NVFP4 backend 使用 NVIDIA Transformer Engine。
- Datasets / prompts:训练 prompts 从 PickScore training split 采样;评估使用 held-out subset。论文未给出 prompt 总样本数,只给出训练调度中每 epoch 48 prompts。
- Reward objectives / evaluation metrics:ImageReward(BLIP-based human preference model)、CLIPScore(CLIP text-image embedding cosine similarity)、PickScore(Pick-a-Pic pairwise preference model)、HPSv2(fine-tuned CLIP-based human preference scorer)。训练时每个 reward model 独立作为 alignment objective,评估时在 held-out set 上用四个指标全部评估。
- Baselines:DanceGRPO、FlowGRPO、AWM、DiffusionNFT;额外效率/消融对比包括 DiffusionNFT 24-in-24、naive BF16 scaling 24-in-96、compiled BF16 rollout、direct NVFP4 quantized rollout、Sol-RL two-stage rollout。
关键训练配置如下:
| Hyperparameter | SANA-1.5 1600M | FLUX.1-dev | SD3.5-Large |
|---|---|---|---|
| Image resolution | |||
| Gradient checkpointing | off | on | off |
| LoRA target modules | to_{q,k,v,out} | to_{q,k,v,out} | attn.{to,add}_{q,k,v,out} |
| LoRA rank / alpha / init | 32 / 64 / Gaussian | 32 / 64 / Gaussian | 32 / 64 / Gaussian |
| Optimizer | AdamW | AdamW | AdamW |
| Learning rate | |||
| Weight decay / epsilon | / | same | same |
| Mixed precision | BF16 | BF16 | BF16 |
| ODE solver | Euler (flow) | DPM-Solver-2 | DPM-Solver-2 |
| Rollout steps / eval steps | 10 / 40 | 10 / 28 | 10 / 40 |
| Per-GPU micro-batch | 16 | 12 | 4 |
| Gradient accumulation steps | 9 | 12 | 36 |
| Timestep fraction / train timesteps | 0.6 / 6 | 0.4 / 4 | 0.6 / 6 |
| Max gradient norm | 1.0 | 1.0 | 0.002 |
| Loss guidance parameter | 1.0 | 1.0 | 1.0 |
| KL penalty | same | same | |
| Advantage clip | 5 | 5 | 5 |
| Prompts per epoch / GPUs | 48 / 8 | 48 / 8 | 48 / 8 |
| Best-of- / images per prompt | 24 / 96 | 24 / 96 | 24 / 96 |
| Exploration steps / model | 6 / Compiled + NVFP4 | same | same |
| Full rollout model | Compiled BF16 | Compiled BF16 | Compiled BF16 |
| EMA decay | 0.9 | 0.9 | 0.9 |
| Old-model decay | linear ramp, rate 0.001, cap 0.5 | same | same |
5. Experimental Results (实验结果)
5.1 Main quantitative results
主表在相同 GPU-hour budget 下评估 FLUX.1。Base w/o CFG 分数为 ImageReward 0.455、CLIPScore 0.2630、PickScore 0.8096、HPSv2 0.2566。
| Method | ImageReward | Δ | CLIPScore | Δ | PickScore | Δ | HPSv2 | Δ |
|---|---|---|---|---|---|---|---|---|
| DanceGRPO | 1.4937 | +1.0387 | 0.2898 | +0.0268 | 0.8807 | +0.0711 | 0.3552 | +0.0986 |
| FlowGRPO | 1.5331 | +1.0781 | 0.2884 | +0.0254 | 0.8743 | +0.0647 | 0.3501 | +0.0935 |
| AWM | 1.6693 | +1.2143 | 0.3039 | +0.0409 | 0.8842 | +0.0746 | 0.3664 | +0.1098 |
| DiffusionNFT | 1.6707 | +1.2157 | 0.2991 | +0.0361 | 0.8852 | +0.0756 | 0.3613 | +0.1047 |
| Sol-RL | 1.7636 | +1.3086 | 0.3089 | +0.0459 | 0.8932 | +0.0836 | 0.3688 | +0.1122 |
Figure 4 解读:该图横轴是 GPU Hours,纵轴是不同 reward metric 下的 alignment score;绿色 Sol-RL 在 SANA、FLUX.1、SD3.5-L 与多种 reward function 组合上均优于灰色 DiffusionNFT,说明 two-stage rollout 不只是单一模型/单一 reward 上的工程加速。
5.2 Efficiency and preservation
| Base Model | Rollout Time Naive | Rollout Time Ours | Rollout Speedup | End-to-End Naive | End-to-End Ours | End-to-End Speedup |
|---|---|---|---|---|---|---|
| FLUX.1 | 184s | 79s | 2.33× | 274s | 169s | 1.62× |
| SD3.5-Large | 451s | 187s | 2.41× | 691s | 427s | 1.61× |
| SANA | 65s | 46s | 1.41× | 95s | 76s | 1.25× |
| Base Model | HPSv2 Naive | HPSv2 Sol-RL |
|---|---|---|
| FLUX.1 | 0.3699 | 0.3688 (-0.29%) |
| SD3.5-Large | 0.3803 | 0.3762 (-1.08%) |
| SANA | 0.3682 | 0.3686 (+0.11%) |
结果说明 Sol-RL 的主要收益来自 rollout generation overhead 的降低;同时在相同训练步数下,HPSv2 与 naive BF16 scaling 的差距约在 1% 内,说明 BF16 regeneration 基本保留了 naive scaling 的 alignment quality。
5.3 Ablations
| Exploration steps | HPSv2 |
|---|---|
| 2 | 0.3587 |
| 4 | 0.3650 |
| 6 | 0.3686 |
| 8 | 0.3659 |
| Exploration pool size | HPSv2 |
|---|---|
| 24 | 0.3569 |
| 48 | 0.3622 |
| 72 | 0.3663 |
| 96 | 0.3686 |
消融显示:探索步数从 2 到 6 提升明显,但 8 steps 反而略降,说明低精度 proxy 不需要完整采样即可提供有效排序;pool size 从 24 到 96 单调提升,支持“大候选池 + 选择性训练”的核心假设。
| Base Model | IS BF16 | IS NVFP4 | CLIP BF16 | CLIP NVFP4 |
|---|---|---|---|---|
| FLUX.1 | 16.84 | 17.85 | 27.44 | 27.10 |
| SANA | 16.02 | 15.94 | 29.53 | 29.43 |
| SD3.5-Large | 16.42 | 17.60 | 28.37 | 28.34 |
这组量化质量结果表明 NVFP4 rollout 在 IS / CLIP 上与 BF16 接近,语义结构没有明显崩坏;但这并不等价于可直接用于训练,因为 Figure 3b 已显示 direct quantized training target 会带来优化退化。
5.4 Ranking consistency analysis
| Reward Metric | Kendall | Spearman | Top-4 Match | Bottom-4 false inclusion | Top-8 Match | Bottom-8 false inclusion | Top-12 Match | Bottom-12 false inclusion |
|---|---|---|---|---|---|---|---|---|
| CLIPScore | 0.752 | 0.900 | 95.7% | 4.5% | 93.9% | 6.2% | 92.2% | 8.2% |
| HPSv2 | 0.827 | 0.943 | 97.6% | 3.4% | 95.5% | 5.3% | 93.9% | 7.1% |
| ImageReward | 0.807 | 0.932 | 97.2% | 3.9% | 95.1% | 5.9% | 93.4% | 7.6% |
| PickScore | 0.806 | 0.934 | 97.1% | 3.8% | 95.4% | 5.6% | 93.6% | 7.2% |
| Overall | 0.798 | 0.927 | 96.9% | 3.9% | 95.0% | 5.7% | 93.3% | 7.5% |
这解释了为什么 Sol-RL 选择 top/bottom seeds 而不是随机采样:FP4 proxy 的全局相关性高,而且在最重要的极端候选上 Top-4 match 达到 96.9%,Bottom-4 false inclusion 只有 3.9%。这些极端样本提供最强 positive / negative advantage signal。
Figure 5 解读:该图可视化 NVFP4 与 BF16 rollout 的差异,显示低精度样本可能有局部偏差,但整体 semantic layout 与结构大体保留;这与 ranking consistency 表相互印证。
Figure 6 解读:上排是 SANA base model,下排是经过 Sol-RL 在 HPSv2、PickScore、CLIPScore、OCR 等 reward 下优化后的输出;改进主要体现在复杂细节、文本/语义对齐和 prompt-specific style consistency。
Figure 7 解读:该图比较 FLUX.1-dev base、Sol-RL、DiffusionNFT、FlowGRPO 在 PickScore-optimized setting 下的图像;Sol-RL 更稳定地产生 prompt 对齐、细节充分且风格连贯的样本。
Figure 8 解读:该图在 HPSv2-optimized setting 下做定性比较,显示 Sol-RL 在人类偏好相关的整体视觉质量与语义匹配上优于 DiffusionNFT / FlowGRPO。
Figure 9 解读:该图在 ImageReward-optimized setting 下比较不同 fine-tuning 方法,Sol-RL 的结果通常具有更好的主体完整性、局部细节和 prompt fidelity。
5.5 Limitations and conclusion
论文作者没有单独列出 limitations section;从方法边界看,Sol-RL 依赖两个条件:第一,低精度 rollout 必须保留 group-relative ranking;第二,reward model 的 top/bottom selection 要能提供有用 advantage signal。若某个模型/任务中 FP4 扰动破坏排序,或 reward model 本身不可靠,two-stage filter 的收益会下降。总体结论是:Sol-RL 证明了低精度不必直接进入训练目标,也可以作为大规模探索器服务于高精度 Diffusion RL,从而在不明显牺牲 alignment quality 的情况下显著降低 rollout 成本。