MARBLE: Multi-Aspect Reward Balance for Diffusion RL
Paper: arXiv:2605.06507 Code: aim-uofa/MARBLE Code reference:
main@b509616c(2026-05-08)
1. Motivation (研究动机)
现有 diffusion RL alignment 已经能用单一 reward 提升 aesthetic、text-image alignment 或 compositional correctness,但真实图像质量是多维的:同一张图既要好看、贴合 prompt,又要会写字、会数数、会处理物体属性和空间关系。论文指出三类常见做法都有结构性问题:训练 one specialist model per reward 不能得到统一模型;weighted-sum reward 会把不同 reward 的监督压成一个标量;sequential fine-tuning 依赖手工 stage schedule,还可能遗忘前面 reward。
本文要解决的具体问题是:在一个 diffusion policy 中同时优化多个 reward,且不需要手工 reward 权重或手工阶段式 curriculum。这个问题值得做,是因为它把“每个指标一个 LoRA/模型”推进到“一个模型覆盖多维质量”,也让 OCR、GenEval 这类 specialist reward 不再被 aesthetic / preference 这类 general reward 淹没。
核心 failure mode 是 paper 称为 specialist sample 的 sample-level mismatch:很多 rollout 只对少数 reward 有信息量。例如一张没有文字的猫图对 OCR reward 几乎无意义;但在 weighted sum 里,它仍会被 OCR 维度参与平均,导致真正有用的 reward advantage 被稀释。论文还用 gradient cosine 诊断显示 weighted-sum update 在 80% mini-batches 里对某个 reward 是负对齐。
2. Idea (核心思想)
MARBLE 的核心 insight 是:multi-reward diffusion RL 的冲突不应该在 reward scalar 上解决,而应该保留每个 reward 的 advantage / gradient,再在 gradient space 中找一个同时不伤害各 reward 的 shared update direction。它把“reward 加权”改成“per-reward policy gradient harmonization”。
关键创新有三点:第一,每个 reward 独立维护 advantage estimator,避免 specialist samples 被无关 reward 稀释;第二,用 normalized per-reward gradients 解一个 simplex QP,得到最小范数 convex combination;第三,利用 DiffusionNFT loss 对 advantage 的 affine 结构,把 次 reward backward amortize 成接近单 reward 的成本,并用 EMA 平滑 。
和现有方法的根本差别:FlowGRPO-style specialists 是“一个 reward 一个模型”;DiffusionNFT sequential 是“一个模型但手写 reward stage”;DiffusionNFT simultaneous 是“先把 reward 加成标量再训练”。MARBLE 不做 scalar reward aggregation,而是在同一 batch 上先得到 ,再学习 update direction 。
3. Method (方法)
代码状态。 论文给出 GitHub repo,但 main@b509616c 只有 README.md,README 状态写明 “Code release in progress; Inference + checkpoints first, training code later”。因此:代码搜索未找到开源实现。下面的伪代码是根据论文公式、Algorithm 1 和 arXiv source 写出的 paper-faithful reference pseudocode,不是 released implementation;training-config 数字也不是从 launch script / experiment config 抽取,因为当前 released repo 没有这些文件。
3.1 Overall framework
Figure 1 解读:左侧是 one-model-per-reward 的 specialist 方案,虽然每个 reward 可单独优化,但不能形成统一模型;中间是 sequential multi-reward training,需要人工决定 reward 顺序和每阶段步数;右侧是 MARBLE 的目标:同一个 diffusion model 在同一训练过程中接收多个 reward 的监督。
Figure 2 解读:每列是一个 sample,每行是一个 reward 的 z-score advantage。高 advantage 往往集中在 OCR、GenEval 等 source-specific reward 上,很少有样本在所有 reward 上同时为正。这说明 weighted sum 会把“只对某一维有用”的样本监督摊薄,是 MARBLE 选择 per-reward advantage 的直接证据。
Figure 3 解读:MARBLE 对同一 prompt batch 生成 images,分别由 个 reward models 打分;每个 reward 产生独立 loss 和 gradient ;gradient harmonization 解出 balancing coefficients ,形成共享 update direction ,再加上 KL term 更新 shared model 。
直觉上,MARBLE 把“这张图在哪个维度值得学”保留下来,而不是先把所有维度求和。对 OCR 来说,只有含文字或文字质量有差异的 samples 真的 informative;对 GenEval 来说,属性、计数、空间关系样本更 informative。per-reward advantage 让这些 specialist signals 不被其他 reward 的噪声抵消;gradient harmonization 再把多个方向合成一个不会明显反向伤害某个 reward 的更新。
3.2 DiffusionNFT base objective
单 reward diffusion RL 优化:
DiffusionNFT 用 NFT loss 实现 advantage-guided update。对 generated sample 和 timestep :
其中 ,,,。关键结构是: 和 与 advantage value 无关,advantage 只通过 affine map 改变 。
3.3 Per-reward advantage decomposition
对 个 rewards ,MARBLE 不先做 ,而是为每个 reward 维护 prompt-conditioned advantage estimator:
每个 产生自己的 和 NFT loss ,同一 batch 上得到 per-reward policy gradient:
这一步让样本只在自己 informative 的 reward dimension 上产生强监督,而不是被无关维度平均掉。
3.4 Gradient harmonization
不同 reward model 的 gradient scale 可能差很多,所以先归一化:
然后解一个 convex QP / MGDA-style problem:
得到的方向是 normalized gradients convex hull 中的 minimum-norm point:
因为 来自 unit gradients,MARBLE 再用原始 gradient 平均范数恢复 scale:
最终更新把 reward direction 和 KL decouple:
KL 不参与 harmonization solve,因为 reward gradients 决定“往哪个质量维度改”,KL 决定“离 reference policy 多远”。
Figure 4 解读:图中比较 weighted-sum direction 与 MARBLE harmonized direction 对各 reward gradients 的 cosine。weighted sum 的 worst-reward alignment 经常为负;MARBLE 把 worst-reward cosine 提升到正区间,并显著降低跨 reward 的 alignment variance。
3.5 Amortized gradient harmonization 与 EMA smoothing
Full harmonization 每步需要 backward passes( 个 reward-specific passes 加一个 KL pass),成本随 reward 数增长。论文利用 NFT loss 的 affine property 证明:当 clamp inactive 时,对 advantage 做 convex combination 等价于对 per-reward NFT gradients 做相同 convex combination。
设 ,,且 、,则:
证明依赖 的 affine 形式:
实际训练中每 steps 执行 full harmonization 刷新 ;中间 steps 使用 cached coefficients 构造 ,只做一次 reward backward。论文主实验使用 ,,并报告未观察到 clamp active。
为避免某个 noisy mini-batch 把 specialist reward 的 coefficient 瞬间压到接近 0,MARBLE 对 coefficient 做 EMA:
论文默认 。EMA 的作用不是改变 QP,而是让 amortized window 内的 update 不被短期 rollout failure 主导。
Figure 5 解读:训练曲线显示五个 optimized rewards 都继续提升,说明 MARBLE 不是只牺牲某些维度来换取另一些维度,而是在同一模型中同步改善 general 与 specialist rewards。
Figure 6 解读: 的轨迹不是固定 0.2,也不简单跟随 raw reward value;它反映当前 batch 下各 reward gradient 的冲突与优化难度。pareto_fallback_used 用来记录 Section 3.5 中 clamp / fallback 条件触发情况。
Figure 7 解读:固定 早期能较快提升 HPSv2 这类 broad quality reward,但对 GenEval 这种 harder specialist reward 收敛更慢且终点更低,说明 adaptive coefficients 对难 reward 的持续分配是必要的。
Figure 8 解读: 太小会让 coefficient 对单 batch noise 过度敏感; 太大又会让 coefficients 反应迟钝。论文在 中选择 ,因为它在稳定性和适应性之间最好。
Figure 9 解读:不同 EMA decay 的 qualitative samples 展示了 coefficient stability 对最终图像质量的影响; 的图像更能同时保持 prompt fidelity、文字/计数约束和整体视觉质量。
3.6 Paper-faithful pseudocode
import torch
import torch.nn.functional as F
def nft_loss(v_theta, v_old, v_target, advantage, beta=1.0, amax=5.0):
"""Paper-faithful NFT loss for one reward dimension."""
r = torch.clamp(0.5 + advantage / (2.0 * amax), 0.0, 1.0)
v_pos = (1.0 - beta) * v_old + beta * v_theta
v_neg = (1.0 + beta) * v_old - beta * v_theta
loss_pos = F.mse_loss(v_pos, v_target, reduction="none").flatten(1).mean(1)
loss_neg = F.mse_loss(v_neg, v_target, reduction="none").flatten(1).mean(1)
return (r * loss_pos + (1.0 - r) * loss_neg).mean()
def per_reward_advantages(reward_matrix, prompt_ids, eps=1e-6):
"""Normalize each reward within prompt groups: reward_matrix [B, K]."""
advantages = torch.zeros_like(reward_matrix)
for prompt in torch.unique(prompt_ids):
mask = prompt_ids == prompt
group_rewards = reward_matrix[mask]
mu = group_rewards.mean(dim=0, keepdim=True)
sigma = group_rewards.std(dim=0, keepdim=True, unbiased=False)
advantages[mask] = (group_rewards - mu) / (sigma + eps)
return advantagesdef solve_mgda_alpha(flat_grads, num_steps=200, lr=0.05):
"""Small differentiable simplex QP surrogate for alpha; paper uses convex QP."""
grads = torch.stack(flat_grads, dim=0)
grads = grads / (grads.norm(dim=1, keepdim=True) + 1e-12)
logits = torch.zeros(grads.shape[0], device=grads.device, requires_grad=True)
opt = torch.optim.Adam([logits], lr=lr)
for _ in range(num_steps):
alpha = torch.softmax(logits, dim=0)
direction = (alpha[:, None] * grads).sum(dim=0)
loss = direction.pow(2).sum()
opt.zero_grad()
loss.backward()
opt.step()
return torch.softmax(logits.detach(), dim=0)def marble_reward_step(model, batch, reward_fns, full_harmonize=True, cached_alpha=None):
"""Reference MARBLE reward step; diffusion sampling details are model-specific."""
images, prompt_ids, diffusion_state = batch
rewards = torch.stack([fn(images) for fn in reward_fns], dim=1)
advantages = per_reward_advantages(rewards, prompt_ids)
per_reward_losses, flat_grads = [], []
for k in range(len(reward_fns)):
v_theta, v_old, v_target = model.velocity_terms(diffusion_state)
loss_k = nft_loss(v_theta, v_old, v_target, advantages[:, k])
per_reward_losses.append(loss_k)
if full_harmonize:
model.zero_grad(set_to_none=True)
loss_k.backward(retain_graph=(k + 1 < len(reward_fns)))
flat_grads.append(torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]))
if full_harmonize:
alpha = solve_mgda_alpha(flat_grads)
else:
alpha = cached_alpha
combined_advantage = (advantages * alpha[None, :]).sum(dim=1)
v_theta, v_old, v_target = model.velocity_terms(diffusion_state)
return nft_loss(v_theta, v_old, v_target, combined_advantage), alphadef update_ema_alpha(alpha_star, alpha_ema, rho=0.7):
"""EMA smoothing used during amortized harmonization."""
if alpha_ema is None:
return alpha_star
alpha = rho * alpha_ema + (1.0 - rho) * alpha_star
return alpha / alpha.sum()
def ddp_extract_per_reward_grads(ddp_model, losses):
"""Paper Algorithm 1: no_sync for per-reward backward, then explicit all_reduce."""
flat_grads = []
for k, loss_k in enumerate(losses):
with ddp_model.no_sync():
loss_k.backward(retain_graph=(k + 1 < len(losses)))
flat = torch.cat([p.grad.flatten() for p in ddp_model.parameters() if p.grad is not None])
flat_grads.append(flat.detach().clone())
ddp_model.zero_grad(set_to_none=True)
for g in flat_grads:
torch.distributed.all_reduce(g, op=torch.distributed.ReduceOp.AVG)
return flat_grads论文公式与 released code 实现差异:released repo 在 main@b509616c 尚无训练实现,无法核验公式与实现差异;当前只能确认 README 声明 training code later。若未来代码发布,需要重点核验 QP solver、gradient normalization、 clamp、 amortization、 EMA 是否与论文一致。
Code reference:
main@b509616c(2026-05-08) — released repo contains onlyREADME.md; pseudocode is based on paper equations / arXiv source; the following mapping records released-repo coverage, not executable released code.
| Paper Concept | Source File | Key Class/Function |
|---|---|---|
| Public repository status | README.md | “Code release in progress”; no training source files in commit |
| MARBLE gradient harmonization | Not present in released repo | No released class/function; paper equations in Chapter/3_Method.tex |
| Amortized coefficients and EMA | Not present in released repo | No released class/function; paper equations in Chapter/3_Method.tex |
| DDP per-reward synchronization | Not present in released repo | Algorithm 1 in Appendix/D_Implementation.tex; no code file |
4. Experimental Setup (实验设置)
Datasets / reward sources and scale. 论文明确使用五个 optimized rewards:PickScore、HPSv2、CLIPScore、OCR accuracy、GenEval;并在 evaluation 中额外报告 Aesthetic Score、ImageReward、UniReward。论文未详细说明 training prompt dataset 的名称和样本数,也未给出 GenEval / OCR / PickScore 等评测集的 prompt count;唯一明确的 human-study scale 是每个方法随机采样 30 张生成图、20 名匿名参与者、按 text-image alignment 与 image quality 两轴 1–5 分打分。
Baselines. 预训练/基础模型包括 SD-XL、SD3.5-L、FLUX.1-Dev、SD3.5-M without CFG、SD3.5-M + CFG;RL baselines 包括三条 single-reward FlowGRPO specialists、DiffusionNFT sequential multi-stage training、DiffusionNFT simultaneous five-reward weighted-sum training。
Metrics. Rule-based metrics 是 GenEval(compositional / object correctness)和 OCR(text rendering accuracy);model-based metrics 是 PickScore、CLIPScore、HPSv2.1、Aesthetic、ImageReward、UniReward。Composite 是 Table 1 中每行的 column-wise z-score 平均值,所有 metric 等权。
Training config. Paper-reported config:base model 是 Stable Diffusion 3.5 Medium;fine-tune LoRA adapters with rank 32 and alpha 64;loss 使用 DiffusionNFT NFT loss;optimizer 是 AdamW;constant learning rate ;训练 jointly optimizes 5 rewards;主实验训练使用 16×NVIDIA H200 GPUs;efficiency table 使用 8×H200;MARBLE 默认 、amortization interval 、EMA decay 、。论文未详细说明 total training steps、global batch size、prompt sampling schedule。
5. Experimental Results (实验结果)
5.1 Main quantitative results
| Model | GenEval | OCR | PickScore | CLIPScore | HPSv2.1 | Aesthetic | ImgRwd | UniRwd | Composite |
|---|---|---|---|---|---|---|---|---|---|
| SD-XL | 0.55 | 0.14 | 22.42 | 0.287 | 0.280 | 5.60 | 0.76 | 2.93 | -0.455 |
| SD3.5-L | 0.71 | 0.68 | 22.91 | 0.289 | 0.288 | 5.50 | 0.96 | 3.25 | +0.116 |
| FLUX.1-Dev | 0.66 | 0.59 | 22.84 | 0.295 | 0.274 | 5.71 | 0.96 | 3.27 | +0.104 |
| SD3.5-M (w/o CFG) | 0.24 | 0.12 | 20.51 | 0.237 | 0.204 | 5.13 | -0.58 | 2.02 | -2.319 |
| + CFG | 0.63 | 0.59 | 22.34 | 0.285 | 0.279 | 5.36 | 0.85 | 3.03 | -0.255 |
| + FlowGRPO (GenEval specialist) | 0.95 | 0.66 | 22.51 | 0.293 | 0.274 | 5.32 | 1.06 | 3.18 | +0.120 |
| + FlowGRPO (OCR specialist) | 0.66 | 0.92 | 22.41 | 0.290 | 0.280 | 5.32 | 0.95 | 3.15 | +0.013 |
| + FlowGRPO (PickScore specialist) | 0.54 | 0.68 | 23.50 | 0.280 | 0.316 | 5.90 | 1.29 | 3.37 | +0.362 |
| + DiffusionNFT | 0.94 | 0.91 | 23.80 | 0.293 | 0.331 | 6.01 | 1.49 | 3.49 | +1.015 |
| + DiffusionNFT | 0.92 | 0.91 | 21.53 | 0.267 | 0.300 | 6.15 | 1.16 | 3.04 | +0.184 |
| + MARBLE | 0.94 | 0.96 | 22.83 | 0.286 | 0.355 | 6.59 | 1.53 | 3.52 | +1.116 |
MARBLE 的主要结论是:它没有拿到 PickScore / CLIPScore 的最高值(DiffusionNFT sequential 更高),但在 OCR、HPSv2.1、Aesthetic、ImageReward、UniReward 和 Composite 上最好;Composite +1.116 高于 DiffusionNFT sequential 的 +1.015,说明小幅牺牲部分 preference / CLIP proxy 后,整体多维质量更强。
Figure 10 解读:qualitative comparisons 显示 MARBLE 相比 weighted-sum / sequential baselines 更能同时满足文字、属性、位置、计数等要求;weighted-sum baseline 往往只在部分维度有效,不能稳定覆盖所有 reward dimensions。
Figure 11 解读:额外 qualitative results 强调同一个 MARBLE 模型能同时改善 text rendering、attribute-object binding、spatial layout 和 counting constraints,而不是为每个能力维护不同 specialist checkpoint。
Figure S1 解读:补充 qualitative comparison 继续展示 MARBLE 在复杂 prompt 下的综合能力,重点是多 reward 同时满足,而不是单一 metric 上的局部最优。
Figure S2 解读:补充样例进一步说明 specialist reward(如文字、计数、属性)不会被 general reward 完全压制,生成结果能保留较高可读性和结构一致性。
Figure S3 解读:该组样例用于检查 MARBLE 的跨 prompt robustness;从论文说明看,它支持主表中的结论:单模型能覆盖多个 reward axes。
Figure S4 解读:最后一组补充比较展示更大范围的 qualitative cases,帮助确认 MARBLE 的改善不是少数 cherry-picked samples,而是跨不同 prompt 类型的趋势。
5.2 Efficiency and ablations
| Method | Relative speed | GPU memory |
|---|---|---|
| Weighted Sum (, DiffusionNFT Baseline) | 59G () | |
| MARBLE w/ amortization (, ) | 67G () | |
| MARBLE w/o amortization () | 67G () |
Amortization 是实用性的关键:full harmonization 只有 speed,而 amortized version 达到 ,接近 weighted-sum baseline;显存从 59G 到 67G,即 。
| Variant | GenEval | OCR | PickScore | CLIPScore | HPSv2.1 | Aesthetic | ImgRwd | UniRwd |
|---|---|---|---|---|---|---|---|---|
| MARBLE full () | 0.93 | 0.96 | 22.62 | 0.283 | 0.355 | 6.59 | 1.52 | 3.45 |
| w/o gradient normalization | FAIL | FAIL | FAIL | FAIL | FAIL | FAIL | FAIL | FAIL |
| Fixed | 0.86 | 0.89 | 22.64 | 0.272 | 0.346 | 6.55 | 1.45 | 3.42 |
| Solve every step | 0.92 | 0.92 | 21.32 | 0.267 | 0.301 | 5.89 | 1.17 | 3.04 |
关键 ablation:去掉 gradient normalization 会失败,说明 QP 必须处理 gradient scale;固定 会削弱 GenEval/OCR;每步解 反而差,原因是单 batch coefficient fluctuation 太强且训练成本高。
| EMA decay | GenEval | OCR | PickScore | CLIPScore | HPSv2.1 | Aesthetic | ImgRwd | UniRwd |
|---|---|---|---|---|---|---|---|---|
| 0.86 | 0.80 | 21.52 | 0.261 | 0.292 | 5.84 | 1.22 | 2.98 | |
| 0.88 | 0.84 | 21.76 | 0.266 | 0.312 | 6.03 | 1.27 | 3.04 | |
| 0.93 | 0.95 | 22.02 | 0.276 | 0.340 | 6.14 | 1.48 | 3.43 | |
| 0.94 | 0.96 | 22.83 | 0.286 | 0.355 | 6.59 | 1.53 | 3.52 | |
| 0.90 | 0.89 | 22.14 | 0.272 | 0.342 | 6.26 | 1.47 | 3.40 |
在所有报告指标上最高; 太抖, 太惯性。Alternative heuristics 也弱于 MARBLE:Reward Dropout 为 GenEval 0.80 / OCR 0.61 / HPSv2.1 0.270;Reward Weighting 为 GenEval 0.90 / OCR 0.92 / HPSv2.1 0.285,说明手工上调 specialist reward 权重不能可靠解决冲突。
5.3 Human study, limitations, and conclusion
| Method | Text-image alignment | Image quality |
|---|---|---|
| DiffusionNFT | 3.60 | 2.79 |
| DiffusionNFT | 4.26 | 3.58 |
| MARBLE | 4.63 | 4.41 |
Human study 支持自动指标结论:即使 MARBLE 的 PickScore / CLIPScore 不最高,人类评分仍在 text-image alignment 和 image quality 上最高。论文同时明确不声称该 user study 有 statistical significance,因此它应被看作 automatic metrics 的补充证据。
作者提到的 limitations / future work:当前只研究 5 个 reward dimensions,扩展到更大、更异构、冲突更强的 reward set 仍是未来方向;另一个方向是扩展到 video generation 和 generative world models,因为这些任务还要同时优化 temporal consistency、motion realism、physical plausibility 和 long-horizon dynamics。当前 released repo 未提供训练代码,也限制了复现实验和核验 paper-code consistency。
总体结论:MARBLE 证明 multi-reward diffusion RL 的关键不是调一个更好的 reward weighted sum,而是保留 per-reward supervision,并在 gradient space 中做冲突协调。它在 SD3.5 Medium 的五 reward 设置中同时提升 OCR、GenEval、HPSv2.1、Aesthetic、ImageReward、UniReward 等维度,并通过 amortized harmonization 把训练速度保持在 weighted-sum baseline 的 。