LeapAlign: Post-Training Flow Matching Models at Any Generation Step by Building Two-Step Trajectories
1. Motivation(研究动机)
- 现有方法的问题: Flow matching 采样轨迹是连续可微的,理论上可以把 reward gradient 沿 ODE 采样链直接反传;但完整长轨迹反传会导致显存成本过高和 gradient explosion。已有 direct-gradient 方法通常只更新靠近最终图像的晚期步,或者像 DRTune 一样通过 stop-gradient 丢掉 nested gradient,从而难以有效学习控制全局构图的早期步。
- 本文要解决的问题: 在不展开完整长轨迹反传的情况下,让 reward 对任意生成步(特别是早期步)产生稳定、可负担的梯度,同时保留跨时间步交互信息,而不是完全切断 nested gradient。
- 为什么重要: 早期 denoising / flow steps 决定图像布局、对象关系和组合结构;若只能微调晚期细节,偏好对齐和 GenEval 类 compositional alignment 会受限。LeapAlign 目标是把长轨迹压缩为两步可微 leap,降低显存和梯度风险,使 Flux 等多步 flow model 能更好地做 post-training alignment。
2. Idea(核心思想)
- 核心洞察: 不必沿完整 ODE 轨迹反传;从一次 full rollout 中抽取 两个时刻,用 flow matching 的 one-step leap prediction 构造 的两步可微近似,再用 latent connector 保证前向值等于真实 rollout latent,梯度却只沿 leap prediction 传播。
- 关键创新: (1) Two-step leap trajectory 把反传深度固定为两次 velocity evaluation;(2) gradient discounting 用 缩小 nested gradient,而不是删除;(3) trajectory-similarity weighting 对更接近真实长轨迹的 leap 样本赋更大训练权重。
- 与代表性方法的本质区别: 相比 ReFL / DRaFT-LV 只在一个中间点做 one-step reward backprop,LeapAlign 可以通过随机 覆盖任意时间段;相比 DRTune 的 input stop-gradient,LeapAlign 保留经 缩放的 nested gradient,因此能利用跨步依赖信号。
3. Method(方法)
总体框架
直觉段落。 长轨迹反传的真正难点不是“reward 能否可微”,而是生成链太深导致每个 velocity prediction 的 Jacobian 级联放大。LeapAlign 的做法像给轨迹建一条可微“捷径”:真实 rollout 仍提供 这些可信锚点,但梯度只穿过 和 两次模型调用。这样既避免了完整采样链显存,又让早期 的模型预测能通过第二跳影响最终 ;相似度权重再过滤掉 leap approximation 误差过大的样本,降低错误 credit assignment。
Figure 1 解读: 图中 是完整生成轨迹。蓝色实线是两次 one-step leap prediction: 与 ;橙色虚线是 latent connector,把 leap prediction 的前向值对齐到真实 rollout latent。紫色反传路径只经过两步 leap,因此 reward gradient 可高效流向早期 timestep。
Flow matching 与 one-step leap
Flow matching 以 Gaussian noise 和真实图像 插值得到:
速度场训练目标为:
在 rectified flow 中 ,one-step leap prediction 为:
Two-step leap trajectory + latent connector
随机选择 ,先预测:
再用 straight-through connector 对齐真实 latent:
这使前向值仍为 full rollout 的 ,但梯度通过 leap prediction 走。
Gradient discounting
无 discount 时,两步 leap 的最终梯度包含两类 single-step terms 和一个 nested term:
LeapAlign 不删除 nested term,而把第二跳改成:
对应梯度变为:
Hinge reward objective 与 trajectory-similarity weighting
为避免 reward hacking,论文采用 hinge-style objective:
Leap approximation 的可靠性用两处 connector error 衡量:
主结果总览与定性图

Figure 2 解读: 左图显示 compositional alignment 训练中的 reward improvement;中图比较 Flux 在多 evaluator 上的提升;右图展示 GenEval 雷达图。整体信息是:LeapAlign 不只是提高 in-domain HPSv2.1,也改善 PickScore、UnifiedReward、GenEval 等 out-of-domain / composition 指标。

Figure 3 解读: GenEval 定性图比较 base Flux、ReFL、DRaFT-LV、DRTune 与 LeapAlign。LeapAlign 更常生成满足对象数量、颜色、位置与属性绑定的图像,说明早期 timestep 更新确实影响全局布局。
组件消融图
Figure 4 解读: 在 HPSv2.1 上最佳; 去掉 nested gradient 后仍优于 DRTune,说明 leap 结构本身有用; 过大则 nested gradient 过强,稳定性下降。
Figure 5 解读: 一步 leap 信息不足,三步 leap 显存更高但性能不升;两步在 performance / memory 之间最好。
Figure 6 解读: 使用真实 rollout 得到的 作为 reward 输入优于用 等预测值,说明 reward 应评估最终生成结果,而不是 leap approximation 本身。
Figure 7 解读: 同时使用 和 的 similarity weight 优于只用单个 connector error 或不加权,说明两个锚点都能过滤不可靠 leap。
Figure 8 解读: 在完整 timestep range 随机选择 优于只覆盖局部区间,支持“任意生成步都应可被更新”的设计。
Figure 9 解读: 随机选择 比固定距离更好且更简单,避免只优化固定跨度的局部 dynamics。
Figure 10 解读: 只通过 nested gradient 更新第一跳时, 增大 gradient norm 并伤害性能; 在性能和梯度规模之间取得更好折中,验证 discounting 的必要性。
Figure 11 解读: 该图跟踪 GenEval score 随 fine-tuning 的提升,LeapAlign 相对 ReFL、DRaFT-LV、DRTune 提升更快且最终更高。

Figure 12 解读: 使用 HPSv3 reward 微调后的 Flux 生成样例,展示方法可从 CLIP-style reward 泛化到 VLM-based reward。

Figure 13 解读: 更多 HPSv3 reward 下的定性结果,强调 LeapAlign 不是只对 HPSv2.1 有效。
伪代码(基于开源参考实现)
Code reference: unofficial reference implementation
main@0f706528(2026-04-17) — pseudocode and mapping based on this commit; officialRockeyCoss/leapalignmain@3612c121(2026-04-20) currently contains project page/assets, not training code.
import torch
@torch.no_grad()
def euler_rollout(v_theta, x1, num_steps=25):
traj = [(1.0, x1)]
x = x1
for s in range(num_steps):
t_cur = 1.0 - s / num_steps
t_next = 1.0 - (s + 1) / num_steps
t = torch.full((x.size(0),), t_cur, device=x.device, dtype=x.dtype)
v = v_theta(x, t)
x = x + (t_next - t_cur) * v
traj.append((t_next, x))
return trajimport torch
def build_leap_trajectory(v_theta, x_k, x_j, x_0, k, j, alpha=0.5):
x_k = x_k.detach()
x_j_true = x_j.detach()
x_0_true = x_0.detach()
t_k = torch.full((x_k.size(0),), k, device=x_k.device, dtype=x_k.dtype)
t_j = torch.full((x_j_true.size(0),), j, device=x_j_true.device, dtype=x_j_true.dtype)
v_k = v_theta(x_k, t_k)
x_hat_j_given_k = x_k - (k - j) * v_k
x_j_connected = x_hat_j_given_k + (x_j_true - x_hat_j_given_k).detach()
x_j_for_v = alpha * x_j_connected + (1.0 - alpha) * x_j_connected.detach()
v_j = v_theta(x_j_for_v, t_j)
x_hat_0_given_j = x_j_connected - j * v_j
x_0_connected = x_hat_0_given_j + (x_0_true - x_hat_0_given_j).detach()
return x_hat_j_given_k, x_hat_0_given_j, x_j_connected, x_0_connectedimport torch
import torch.nn.functional as F
def trajectory_similarity_weight(x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, tau=0.1):
with torch.no_grad():
d_j = (x_j - x_hat_j_given_k).abs().flatten(1).mean(dim=1)
d_0 = (x_0 - x_hat_0_given_j).abs().flatten(1).mean(dim=1)
tau_t = torch.as_tensor(tau, device=d_j.device, dtype=d_j.dtype)
return 1.0 / (torch.maximum(d_j, tau_t) + torch.maximum(d_0, tau_t))
def leap_align_loss(x_0_connected, x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, reward_fn, lam=0.55, tau=0.1):
reward = reward_fn(x_0_connected)
raw = F.relu(lam - reward)
w_sim = trajectory_similarity_weight(x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, tau)
return (w_sim * raw).mean()import torch
def leap_align_step(v_theta, x1, reward_fn, num_ode_steps=25, alpha=0.5, lam=0.55, tau=0.1):
traj = euler_rollout(v_theta, x1, num_steps=num_ode_steps)
i_k = torch.randint(1, num_ode_steps + 1, (1,)).item()
i_j = torch.randint(0, i_k, (1,)).item()
k, j = i_k / num_ode_steps, i_j / num_ode_steps
x_k = traj[num_ode_steps - i_k][1]
x_j = traj[num_ode_steps - i_j][1]
x_0 = traj[-1][1]
xhj, xh0, xjc, x0c = build_leap_trajectory(v_theta, x_k, x_j, x_0, k, j, alpha)
return leap_align_loss(x0c, x_j, x_0, xhj, xh0, reward_fn, lam, tau)
def train_update(v_theta, x1, reward_fn, optimizer):
loss = leap_align_step(v_theta, x1, reward_fn)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(v_theta.parameters(), 1.0)
optimizer.step()
return loss.detach()Code-to-paper mapping table:
| Paper Concept | Source File | Key Class/Function |
|---|---|---|
| Full ODE rollout | leap_align.py | euler_rollout |
| Random sampling over rollout grid | leap_align.py | sample_leap_timesteps |
| Eq. (4)–(7) two-step leap + latent connectors | leap_align.py | build_leap_trajectory, LeapTrajectory |
| Eq. (9) gradient discounting | leap_align.py | build_leap_trajectory (x_j_for_v) |
| Eq. (12) trajectory similarity | leap_align.py | trajectory_similarity_weight |
| Eq. (11), (13) hinge reward loss | leap_align.py | leap_align_loss, LeapAlignLoss |
| End-to-end loss construction | leap_align.py | leap_align_step |
| Optimizer backward / grad clip / AdamW / EMA | train_leap_align.py | main training loop |
| Toy runnable training / optimizer / EMA | train_leap_align.py | main, pretrain_flow, pretrain_reward |
4. Experimental Setup(实验设置)
- Datasets: General preference alignment 使用 HPDv2 抽样 50,000 prompts 训练,也使用 MJHQ-30k prompts;HPDv2 测试为 400 prompts × 4 images = 1,600 images,MJHQ-30k 随机留 500 prompts 测试。Compositional alignment 使用 GenEval official scripts 生成的 50,000 prompts 训练,类别比例 Position:Counting:Attribute Binding:Colors:Two Objects:Single Object = 7:5:3:1:1:0;测试按 GenEval 553 prompts × 4 images。
- Baselines: Policy-gradient:DanceGRPO、MixGRPO;Direct-gradient:ReFL、DRaFT-LV、DRTune。论文说明 DRaFT-LV / DRTune 无官方实现,由作者按伪代码复现;ReFL 从官方实现适配 Flux。
- Metrics: HPSv2.1、HPSv3、PickScore、ImageReward 衡量 human preference;UnifiedReward-Alignment 衡量 image-text alignment;UnifiedReward-IQ 衡量 image quality;GenEval Overall 与六个子类(Single Object、Two Objects、Counting、Colors、Position、Attribute Binding)衡量组合生成正确性。
- Training config: Base model 为 FLUX.1-dev;默认 reward HPSv2.1,,;AdamW lr=1e-5,batch size 64,weight decay 1e-4,EMA decay 0.995,;训练 300 iterations on 16 GPUs。训练 rollout:720×720, 25 steps, CFG 3.5;评估:720×720, 50 steps, CFG 3.5。附录补充:HPSv2.1 时 ;PickScore / HPSv3 时 ;HPSv3 使用 lr=8e-6; 对 HPSv2.1 / PickScore / HPSv3 分别为 0.55 / 0.4 / 13.5。
5. Experimental Results(实验结果)
方法能力对比(Table 1)
| Method | Early Steps | Nested Gradient | Leap Trajectory | Multi-Step |
|---|---|---|---|---|
| ReFL | ✗ | ✗ | ✗ | ✗ |
| DRaFT-LV | ✗ | ✗ | ✗ | ✗ |
| DRTune | ✓ | ✗ | ✗ | ✓ |
| LeapAlign | ✓ | ✓ | ✓ | ✓ |
Flux 主表(Table 2)
| Method | HPSv2.1 ↑ | HPSv3 ↑ | PickScore ↑ | UR-Align ↑ | UR-IQ ↑ | ImageReward ↑ | GenEval Overall ↑ |
|---|---|---|---|---|---|---|---|
| Flux | 0.3078 | 13.5020 | 22.7902 | 3.4514 | 3.5708 | 1.0455 | 0.6535 |
| DanceGRPO | 0.3451 | 14.8336 | 23.1186 | 3.4660 | 3.6199 | 1.2347 | 0.6775 |
| MixGRPO† | 0.3692 | 14.7530 | 23.5184 | 3.4393 | 3.6241 | 1.6155 | 0.7232 |
| ReFL‡ | 0.3852 | 15.5127 | 23.6299 | 3.4786 | 3.6870 | 1.3468 | 0.7011 |
| DRaFT-LV* | 0.3859 | 15.3699 | 23.6437 | 3.4868 | 3.6887 | 1.3384 | 0.7024 |
| DRTune* | 0.3882 | 15.5606 | 23.5185 | 3.4793 | 3.6679 | 1.3562 | 0.7101 |
| LeapAlign | 0.4092 | 15.7678 | 23.7137 | 3.4984 | 3.7244 | 1.5104 | 0.7420 |
GenEval 子类(Table 2)
| Method | Single Obj. | Two Obj. | Count | Color | Position | AttrB |
|---|---|---|---|---|---|---|
| Flux | 99.38 | 86.62 | 66.88 | 74.47 | 19.50 | 45.25 |
| MixGRPO† | 99.69 | 93.69 | 80.00 | 80.05 | 24.25 | 56.25 |
| DRTune* | 99.38 | 93.69 | 73.12 | 76.86 | 27.50 | 55.50 |
| LeapAlign | 99.38 | 96.46 | 72.50 | 80.59 | 30.25 | 66.00 |
不同 reward / prompt set(Table 3)
| Method | HPSv2.1 ↑ | PickScore ↑ | HPSv3 ↑ |
|---|---|---|---|
| Flux | 0.3078 | 22.7902 | 10.7624 |
| ReFL‡ | 0.3852 | 25.2373 | 11.7642 |
| DRaFT-LV* | 0.3859 | 24.9596 | 11.2701 |
| DRTune* | 0.3882 | 25.1021 | 12.0023 |
| LeapAlign | 0.4092 | 25.7589 | 12.5855 |
SD3.5-M 泛化(Appendix Table 4)
| Method | HPSv2.1 ↑ | HPSv3 ↑ | PickScore ↑ | UR-Align ↑ | UR-IQ ↑ | ImageReward ↑ |
|---|---|---|---|---|---|---|
| SD3.5-M | 0.2967 | 12.2846 | 22.5189 | 3.4436 | 3.5565 | 1.0614 |
| ReFL‡ | 0.3833 | 15.2488 | 23.5015 | 3.4810 | 3.6872 | 1.4239 |
| DRaFT-LV* | 0.3506 | 14.6022 | 23.0747 | 3.4691 | 3.6519 | 1.3008 |
| DRTune* | 0.3828 | 15.2541 | 23.4711 | 3.4738 | 3.6667 | 1.4320 |
| LeapAlign | 0.3915 | 15.5780 | 23.6180 | 3.4896 | 3.7182 | 1.4736 |
阈值消融(Appendix Table 5)
| HPSv2.1 | HPSv3 | PickScore | ImageReward | |
|---|---|---|---|---|
| 0.35 | 0.3860 | 15.3635 | 23.4735 | 1.3510 |
| 0.55 | 0.4092 | 15.7678 | 23.7137 | 1.5104 |
| 0.75 | 0.4091 | 15.7274 | 23.7061 | 1.4844 |
| 0.95 | 0.4023 | 15.7254 | 23.5082 | 1.3888 |
关键发现、局限与结论
- 主结果: LeapAlign 在 Flux 上取得最高 HPSv2.1、HPSv3、PickScore、UnifiedReward-Alignment、UnifiedReward-IQ 与 GenEval Overall;虽只用 HPSv2.1 reward 训练,仍超过用三 reward 的 MixGRPO 的 HPSv2.1 / PickScore,并在 ImageReward 上保持竞争力。
- 消融: 、两步 leap、真实 reward input、 similarity weighting、全区间随机 timestep selection 均是关键设计; 在 HPSv2.1 设置下最好。
- 局限: 需要可微 reward model;不可微 reward 需要未来借助 differentiable value model 等扩展。作者还指出未来会把 LeapAlign 实现并改进到 video generation。