1. Motivation (研究动机)

  • 当前 Diffusion RL post-training 已能通过 GRPO / DiffusionNFT / FlowGRPO 等方法提升 text-to-image 模型的人类偏好对齐,但它们的有效更新高度依赖 rollout group 内候选样本的相对奖励差异;扩大 rollout group size 往往能带来更强的 alignment gain。
  • 瓶颈在于大模型 rollout 极贵:例如 FLUX.1-12B 这类基础模型要为每个 prompt 生成大量候选,BF16 naive scaling 会把绝大多数最终不会用于训练的候选也完整高精度采样,导致 wall-clock 与 GPU-hour 成本成为主要限制。
  • 本文要解决的问题是:能否用低精度 NVFP4 先大规模便宜探索候选空间,再只把最有训练价值的高对比样本用 BF16 重生成用于优化,从而同时保留 rollout scaling 的算法收益和高精度训练的稳定性。解决后,Diffusion RL 可以把更多预算花在“找更有信息量的样本”上,而不是浪费在训练不会用到的候选图像上。

2. Idea (核心思想)

  • 核心 insight:FP4 rollout 不适合直接作为训练 target,因为量化误差会污染梯度和 policy update;但它足够保留同一 prompt 内候选样本的相对 reward ranking,因此可以作为高吞吐的 proxy filter。
  • Sol-RL 把 rollout 拆成两个阶段:Stage 1 用 NVFP4 compiled model 以少步数生成大候选池并排序;Stage 2 只保留 top/bottom 高对比 seed,用 BF16 full rollout 重生成高保真训练样本,再用 DiffusionNFT objective 更新 policy。
  • 与 DiffusionNFT / FlowGRPO 的主要差异不在 reward objective 本身,而在 rollout 数据生产路径:它们通常直接在训练精度下采样用于优化的 group;Sol-RL 则把“探索候选”和“生成训练 target”解耦,用低精度承担搜索、高精度承担优化。

3. Method (方法)

Overall framework. Sol-RL 的整体流程是“FP4 explore, BF16 train”:每轮训练先把当前 policy 的 LoRA/权重同步到 NVFP4 inference engine,针对每个 prompt 采样 个独立 initial noise,用 个 denoising steps 做快速候选生成并用 reward model 打分;然后保留 top- 与 bottom- 的 seed,其中 ;最后用 BF16 compiled model 和完整 rollout steps 从这些 seed 重生成样本,用 DiffusionNFT loss 做 policy update。

Figure 2 解读:图中左侧是低精度 exploration branch,NVFP4 model 负责便宜地产生大候选池;中间的 reward sorting 只传递 seed / ranking 信息,而不是把 FP4 图像直接用于训练;右侧 BF16 branch 根据被选中的 top/bottom seed 重生成高保真 rollout,再进入 DiffusionNFT optimization。这个结构的关键是让量化误差只影响候选筛选,不进入最终梯度 target。

Figure 1 解读:左侧展示 FLUX.1 与 SANA 经 Sol-RL fine-tuning 后的生成样例,强调该方法不只是加速,还能保持高质量视觉输出;右侧训练曲线显示在 ImageReward 上 Sol-RL 能以更少 wall-clock 时间达到 baseline 同等 reward,最高报告 convergence speedup,并收敛到更高 reward ceiling。

3.1 Preliminaries: Diffusion RL 与 GRPO

最简单的 REINFORCE policy gradient 写作:

加入 prompt-level baseline 后为:

GRPO 对每个 prompt 生成 个候选 ,用 group 内 reward 标准化得到 advantage:

对应 PPO-style clipped surrogate objective 为:

3.2 FP4 quantization 与直接量化 rollout 的问题

FP4 把高精度 tensor 通过 group shared scale 映射到 4-bit value,论文写作:

其中 是 shared scaling factor, 是投影函数。OCP MXFP4 使用 32-element group + E8M0 scale;NVIDIA NVFP4 使用 16-element group + E4M3 scale。FP4 的吞吐优势明显,但直接把 FP4 rollout 放进 RL training pipeline 会让训练不稳定,因为优化阶段依赖的是图像/trajectory 的高保真 target,而不是只依赖相对排序。

[Large figure omitted: fig3_group.svg exceeded Cloudflare Pages single-file limit.]

Figure 3a–3c 解读:这组子图在原始 arXiv 源码中属于同一个组合图,应一起看:Figure 3a 比较不同 -in- precision setting 的 iteration time breakdown,说明 naive BF16 scaling 的主要成本集中在生成大量候选上,而 FP4 rollout 能显著降低 generation time;Figure 3b 展示 direct FP4 quantized rollout 相对 BF16 baseline 的训练性能退化和不稳定,支撑“低精度可以加速 sampling,但不能无条件替代训练 target”的负面结论;Figure 3c 显示 FP4 与 BF16 在同一 group 内 reward ranking 的分布接近对角线,说明即使绝对 reward 有扰动,候选之间的相对顺序仍高度一致,Sol-RL 正是利用这个“排序可靠、训练不可靠”的分离性质。

3.3 Two-stage rollout 的直觉

Sol-RL 的直觉类似先用低成本检索器从大库中找 hard positives / hard negatives,再用高质量生成器重建训练样本。GRPO/DiffusionNFT 需要的是 group 内有足够 reward spread 的样本来形成强 advantage signal,而不是需要所有候选都作为训练 target;因此,先用 NVFP4 找到 top/bottom seed,再用 BF16 重生成这些 seed,能把探索预算放大到 ,同时把优化实际处理的样本数控制在 。这种设计把 FP4 误差限制在“选择哪些 seed”这一离散决策中,而不是让误差连续进入 loss。

3.4 Theoretical justification: ranking preservation

附录把低精度 rollout 建模为 ODE vector field 的有界扰动。高精度 trajectory 满足:

低精度 accelerated trajectory 满足:

其中 是 FP4 rounding 与低精度 solver arithmetic 的有效扰动。若 -Lipschitz,则 Grönwall inequality 给出最终 sample deviation bound:

若 reward model -Lipschitz,则 fixed seed 的 absolute reward error 也被上式诱导的 sample deviation 控制。论文进一步用 Extreme Value Theory 论证:当 group size 增大且 top/bottom 样本的 reward margin 足够大时,小扰动不容易改变极端候选集合,因此 FP4 适合作为 top/bottom seed filter。

3.5 Pseudocode: NVFP4 exploration filter

import torch
 
@torch.no_grad()
def fp4_exploration_filter(policy, prompts, reward_fn, *, n=96, k=24, steps=6):
    fp4_policy = quantize_to_nvfp4_compiled(policy)
    selected = []
 
    for prompt in prompts:
        seeds = torch.randint(0, 2**31 - 1, (n,), device="cuda")
        images = []
        for seed in seeds:
            noise = make_initial_noise(seed, prompt)
            image = fp4_policy.sample(prompt, noise=noise, num_inference_steps=steps)
            images.append(image)
 
        rewards = reward_fn(images, [prompt] * n)
        order = torch.argsort(rewards)
        bottom = order[: k // 2]
        top = order[-k // 2 :]
        keep = torch.cat([bottom, top])
        selected.append({"prompt": prompt, "seeds": seeds[keep], "proxy_rewards": rewards[keep]})
 
    return selected

3.6 Pseudocode: BF16 regeneration and reward/advantage construction

import torch
 
@torch.no_grad()
def bf16_regenerate_rollouts(policy, selected_groups, reward_fn, *, steps=10):
    batch = []
    for group in selected_groups:
        prompt = group["prompt"]
        images, trajectories = [], []
        for seed in group["seeds"]:
            noise = make_initial_noise(seed, prompt)
            image, traj = policy.sample_with_logprob(
                prompt,
                noise=noise,
                num_inference_steps=steps,
                dtype=torch.bfloat16,
            )
            images.append(image)
            trajectories.append(traj)
 
        rewards = reward_fn(images, [prompt] * len(images))
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
        batch.append({"prompt": prompt, "trajectories": trajectories, "advantages": advantages})
    return collate_rollout_batch(batch)

3.7 Pseudocode: DiffusionNFT-style policy update

import torch
import torch.nn.functional as F
 
 
def diffusion_nft_update(policy, ref_policy, old_policy, rollout_batch, optimizer, *, beta=1e-4, adv_clip=5.0):
    optimizer.zero_grad()
    total_loss = 0.0
 
    for batch in rollout_batch:
        x0 = batch["x0"]
        timesteps = batch["timesteps"]
        cond = batch["prompt_embeds"]
        advantages = torch.clamp(batch["advantages"], -adv_clip, adv_clip)
        r = torch.clamp((advantages / adv_clip) / 2.0 + 0.5, 0.0, 1.0)
 
        pred = policy(x0, timesteps, cond)
        old_pred = old_policy(x0, timesteps, cond).detach()
        ref_pred = ref_policy(x0, timesteps, cond).detach()
 
        positive_loss = F.mse_loss(pred, x0, reduction="none").flatten(1).mean(dim=1)
        negative_loss = F.mse_loss(2 * old_pred - pred, x0, reduction="none").flatten(1).mean(dim=1)
        policy_loss = (r * positive_loss + (1.0 - r) * negative_loss).mean()
        kl_loss = F.mse_loss(pred, ref_pred)
        total_loss = total_loss + policy_loss + beta * kl_loss
 
    total_loss.backward()
    optimizer.step()
    return total_loss.detach()

3.8 Pseudocode: LoRA synchronization back to compiled inference models

@torch.no_grad()
def sync_after_update(trainable_lora, inference_models, *, adapter_name="old"):
    old_lora = ema_or_linear_ramp_update(trainable_lora)
    for model in inference_models.values():
        base = unwrap_compiled(model)
        sync_lora_to_inference(old_lora, base, adapter_name=adapter_name)
        if model.requires_nvfp4:
            re_quantize_in_place(base)

Code reference: main @ 1bfb9352 (2026-04-14) — pseudocode and mapping based on this commit

Paper ConceptSource FileKey Class/Function
Sol-RL run recipes and launch commandsdocs/sol_rl.mdtraining commands for SANA / FLUX.1 / SD3.5-L
Shared RL config defaultsconfigs/sol_rl/base.pyget_config()
SANA experiment variantsconfigs/sol_rl/sana.py_build_reward(), sana_sol_rl_*(), sana_naive_scaling_*(), sana_naive_quant_*()
FLUX.1 experiment variantsconfigs/sol_rl/flux1.pyflux1_sol_rl_*(), flux1_compile_*(), flux1_naive_quant_*()
SD3.5 experiment variantsconfigs/sol_rl/sd3.pysd3_sol_rl_*(), sd3_compile_*(), sd3_naive_quant_*()
NVFP4 support and model synctrain_scripts/sol_rl/train_utils.pyNVFP4_RECIPE, ensure_transformer_engine_available(), wrap_forward_with_fp8(), replace_linear_with_te(), sync_lora_to_inference()
SANA rollout/training looptrain_scripts/sol_rl/train_sana.py_rollout_for_one_prompt(), stat_tracker.update(), policy loss block, sync_lora_to_inference()
FLUX.1 rollout/training looptrain_scripts/sol_rl/train_flux1.pyFLUX-specific rollout/training loop
SD3.5 rollout/training looptrain_scripts/sol_rl/train_sd3.pySD3-specific rollout/training loop
Reward modelsdiffusion/post_training/rewards.pyHPSv2Scorer, ClipScorer, PickScoreScorer, ImageRewardScorer, multi_score()

4. Experimental Setup (实验设置)

  • Models / hardware:实验覆盖 SANA-1.5 1600M、FLUX.1-dev、Stable Diffusion 3.5-Large;全部实验使用 8 张 NVIDIA B200 GPU;NVFP4 backend 使用 NVIDIA Transformer Engine。
  • Datasets / prompts:训练 prompts 从 PickScore training split 采样;评估使用 held-out subset。论文未给出 prompt 总样本数,只给出训练调度中每 epoch 48 prompts。
  • Reward objectives / evaluation metrics:ImageReward(BLIP-based human preference model)、CLIPScore(CLIP text-image embedding cosine similarity)、PickScore(Pick-a-Pic pairwise preference model)、HPSv2(fine-tuned CLIP-based human preference scorer)。训练时每个 reward model 独立作为 alignment objective,评估时在 held-out set 上用四个指标全部评估。
  • Baselines:DanceGRPO、FlowGRPO、AWM、DiffusionNFT;额外效率/消融对比包括 DiffusionNFT 24-in-24、naive BF16 scaling 24-in-96、compiled BF16 rollout、direct NVFP4 quantized rollout、Sol-RL two-stage rollout。

关键训练配置如下:

HyperparameterSANA-1.5 1600MFLUX.1-devSD3.5-Large
Image resolution
Gradient checkpointingoffonoff
LoRA target modulesto_{q,k,v,out}to_{q,k,v,out}attn.{to,add}_{q,k,v,out}
LoRA rank / alpha / init32 / 64 / Gaussian32 / 64 / Gaussian32 / 64 / Gaussian
OptimizerAdamWAdamWAdamW
Learning rate
Weight decay / epsilon / samesame
Mixed precisionBF16BF16BF16
ODE solverEuler (flow)DPM-Solver-2DPM-Solver-2
Rollout steps / eval steps10 / 4010 / 2810 / 40
Per-GPU micro-batch16124
Gradient accumulation steps91236
Timestep fraction / train timesteps0.6 / 60.4 / 40.6 / 6
Max gradient norm1.01.00.002
Loss guidance parameter 1.01.01.0
KL penalty samesame
Advantage clip555
Prompts per epoch / GPUs48 / 848 / 848 / 8
Best-of- / images per prompt 24 / 9624 / 9624 / 96
Exploration steps / model6 / Compiled + NVFP4samesame
Full rollout modelCompiled BF16Compiled BF16Compiled BF16
EMA decay0.90.90.9
Old-model decaylinear ramp, rate 0.001, cap 0.5samesame

5. Experimental Results (实验结果)

5.1 Main quantitative results

主表在相同 GPU-hour budget 下评估 FLUX.1。Base w/o CFG 分数为 ImageReward 0.455、CLIPScore 0.2630、PickScore 0.8096、HPSv2 0.2566。

MethodImageRewardΔCLIPScoreΔPickScoreΔHPSv2Δ
DanceGRPO1.4937+1.03870.2898+0.02680.8807+0.07110.3552+0.0986
FlowGRPO1.5331+1.07810.2884+0.02540.8743+0.06470.3501+0.0935
AWM1.6693+1.21430.3039+0.04090.8842+0.07460.3664+0.1098
DiffusionNFT1.6707+1.21570.2991+0.03610.8852+0.07560.3613+0.1047
Sol-RL1.7636+1.30860.3089+0.04590.8932+0.08360.3688+0.1122

Figure 4 解读:该图横轴是 GPU Hours,纵轴是不同 reward metric 下的 alignment score;绿色 Sol-RL 在 SANA、FLUX.1、SD3.5-L 与多种 reward function 组合上均优于灰色 DiffusionNFT,说明 two-stage rollout 不只是单一模型/单一 reward 上的工程加速。

5.2 Efficiency and preservation

Base ModelRollout Time NaiveRollout Time OursRollout SpeedupEnd-to-End NaiveEnd-to-End OursEnd-to-End Speedup
FLUX.1184s79s2.33×274s169s1.62×
SD3.5-Large451s187s2.41×691s427s1.61×
SANA65s46s1.41×95s76s1.25×
Base ModelHPSv2 NaiveHPSv2 Sol-RL
FLUX.10.36990.3688 (-0.29%)
SD3.5-Large0.38030.3762 (-1.08%)
SANA0.36820.3686 (+0.11%)

结果说明 Sol-RL 的主要收益来自 rollout generation overhead 的降低;同时在相同训练步数下,HPSv2 与 naive BF16 scaling 的差距约在 1% 内,说明 BF16 regeneration 基本保留了 naive scaling 的 alignment quality。

5.3 Ablations

Exploration steps HPSv2
20.3587
40.3650
60.3686
80.3659
Exploration pool size HPSv2
240.3569
480.3622
720.3663
960.3686

消融显示:探索步数从 2 到 6 提升明显,但 8 steps 反而略降,说明低精度 proxy 不需要完整采样即可提供有效排序;pool size 从 24 到 96 单调提升,支持“大候选池 + 选择性训练”的核心假设。

Base ModelIS BF16IS NVFP4CLIP BF16CLIP NVFP4
FLUX.116.8417.8527.4427.10
SANA16.0215.9429.5329.43
SD3.5-Large16.4217.6028.3728.34

这组量化质量结果表明 NVFP4 rollout 在 IS / CLIP 上与 BF16 接近,语义结构没有明显崩坏;但这并不等价于可直接用于训练,因为 Figure 3b 已显示 direct quantized training target 会带来优化退化。

5.4 Ranking consistency analysis

Reward MetricKendall Spearman Top-4 MatchBottom-4 false inclusionTop-8 MatchBottom-8 false inclusionTop-12 MatchBottom-12 false inclusion
CLIPScore0.7520.90095.7%4.5%93.9%6.2%92.2%8.2%
HPSv20.8270.94397.6%3.4%95.5%5.3%93.9%7.1%
ImageReward0.8070.93297.2%3.9%95.1%5.9%93.4%7.6%
PickScore0.8060.93497.1%3.8%95.4%5.6%93.6%7.2%
Overall0.7980.92796.9%3.9%95.0%5.7%93.3%7.5%

这解释了为什么 Sol-RL 选择 top/bottom seeds 而不是随机采样:FP4 proxy 的全局相关性高,而且在最重要的极端候选上 Top-4 match 达到 96.9%,Bottom-4 false inclusion 只有 3.9%。这些极端样本提供最强 positive / negative advantage signal。

Figure 5 解读:该图可视化 NVFP4 与 BF16 rollout 的差异,显示低精度样本可能有局部偏差,但整体 semantic layout 与结构大体保留;这与 ranking consistency 表相互印证。

Figure 6 解读:上排是 SANA base model,下排是经过 Sol-RL 在 HPSv2、PickScore、CLIPScore、OCR 等 reward 下优化后的输出;改进主要体现在复杂细节、文本/语义对齐和 prompt-specific style consistency。

Figure 7 解读:该图比较 FLUX.1-dev base、Sol-RL、DiffusionNFT、FlowGRPO 在 PickScore-optimized setting 下的图像;Sol-RL 更稳定地产生 prompt 对齐、细节充分且风格连贯的样本。

Figure 8 解读:该图在 HPSv2-optimized setting 下做定性比较,显示 Sol-RL 在人类偏好相关的整体视觉质量与语义匹配上优于 DiffusionNFT / FlowGRPO。

Figure 9 解读:该图在 ImageReward-optimized setting 下比较不同 fine-tuning 方法,Sol-RL 的结果通常具有更好的主体完整性、局部细节和 prompt fidelity。

5.5 Limitations and conclusion

论文作者没有单独列出 limitations section;从方法边界看,Sol-RL 依赖两个条件:第一,低精度 rollout 必须保留 group-relative ranking;第二,reward model 的 top/bottom selection 要能提供有用 advantage signal。若某个模型/任务中 FP4 扰动破坏排序,或 reward model 本身不可靠,two-stage filter 的收益会下降。总体结论是:Sol-RL 证明了低精度不必直接进入训练目标,也可以作为大规模探索器服务于高精度 Diffusion RL,从而在不明显牺牲 alignment quality 的情况下显著降低 rollout 成本。