LeapAlign: Post-Training Flow Matching Models at Any Generation Step by Building Two-Step Trajectories

1. Motivation(研究动机)

  • 现有方法的问题: Flow matching 采样轨迹是连续可微的,理论上可以把 reward gradient 沿 ODE 采样链直接反传;但完整长轨迹反传会导致显存成本过高和 gradient explosion。已有 direct-gradient 方法通常只更新靠近最终图像的晚期步,或者像 DRTune 一样通过 stop-gradient 丢掉 nested gradient,从而难以有效学习控制全局构图的早期步。
  • 本文要解决的问题: 在不展开完整长轨迹反传的情况下,让 reward 对任意生成步(特别是早期步)产生稳定、可负担的梯度,同时保留跨时间步交互信息,而不是完全切断 nested gradient。
  • 为什么重要: 早期 denoising / flow steps 决定图像布局、对象关系和组合结构;若只能微调晚期细节,偏好对齐和 GenEval 类 compositional alignment 会受限。LeapAlign 目标是把长轨迹压缩为两步可微 leap,降低显存和梯度风险,使 Flux 等多步 flow model 能更好地做 post-training alignment。

2. Idea(核心思想)

  • 核心洞察: 不必沿完整 ODE 轨迹反传;从一次 full rollout 中抽取 两个时刻,用 flow matching 的 one-step leap prediction 构造 的两步可微近似,再用 latent connector 保证前向值等于真实 rollout latent,梯度却只沿 leap prediction 传播。
  • 关键创新: (1) Two-step leap trajectory 把反传深度固定为两次 velocity evaluation;(2) gradient discounting 缩小 nested gradient,而不是删除;(3) trajectory-similarity weighting 对更接近真实长轨迹的 leap 样本赋更大训练权重。
  • 与代表性方法的本质区别: 相比 ReFL / DRaFT-LV 只在一个中间点做 one-step reward backprop,LeapAlign 可以通过随机 覆盖任意时间段;相比 DRTune 的 input stop-gradient,LeapAlign 保留经 缩放的 nested gradient,因此能利用跨步依赖信号。

3. Method(方法)

总体框架

直觉段落。 长轨迹反传的真正难点不是“reward 能否可微”,而是生成链太深导致每个 velocity prediction 的 Jacobian 级联放大。LeapAlign 的做法像给轨迹建一条可微“捷径”:真实 rollout 仍提供 这些可信锚点,但梯度只穿过 两次模型调用。这样既避免了完整采样链显存,又让早期 的模型预测能通过第二跳影响最终 ;相似度权重再过滤掉 leap approximation 误差过大的样本,降低错误 credit assignment。

Figure 1 解读: 图中 是完整生成轨迹。蓝色实线是两次 one-step leap prediction;橙色虚线是 latent connector,把 leap prediction 的前向值对齐到真实 rollout latent。紫色反传路径只经过两步 leap,因此 reward gradient 可高效流向早期 timestep。

Flow matching 与 one-step leap

Flow matching 以 Gaussian noise 和真实图像 插值得到:

速度场训练目标为:

rectified flow,one-step leap prediction 为:

Two-step leap trajectory + latent connector

随机选择 ,先预测:

再用 straight-through connector 对齐真实 latent:

这使前向值仍为 full rollout 的 ,但梯度通过 leap prediction 走。

Gradient discounting

无 discount 时,两步 leap 的最终梯度包含两类 single-step terms 和一个 nested term:

LeapAlign 不删除 nested term,而把第二跳改成:

对应梯度变为:

Hinge reward objective 与 trajectory-similarity weighting

为避免 reward hacking,论文采用 hinge-style objective:

Leap approximation 的可靠性用两处 connector error 衡量:

主结果总览与定性图

Figure 2 解读: 左图显示 compositional alignment 训练中的 reward improvement;中图比较 Flux 在多 evaluator 上的提升;右图展示 GenEval 雷达图。整体信息是:LeapAlign 不只是提高 in-domain HPSv2.1,也改善 PickScore、UnifiedReward、GenEval 等 out-of-domain / composition 指标。

Figure 3 解读: GenEval 定性图比较 base Flux、ReFL、DRaFT-LV、DRTune 与 LeapAlign。LeapAlign 更常生成满足对象数量、颜色、位置与属性绑定的图像,说明早期 timestep 更新确实影响全局布局。

组件消融图

Figure 4 解读: 在 HPSv2.1 上最佳; 去掉 nested gradient 后仍优于 DRTune,说明 leap 结构本身有用; 过大则 nested gradient 过强,稳定性下降。

Figure 5 解读: 一步 leap 信息不足,三步 leap 显存更高但性能不升;两步在 performance / memory 之间最好。

Figure 6 解读: 使用真实 rollout 得到的 作为 reward 输入优于用 等预测值,说明 reward 应评估最终生成结果,而不是 leap approximation 本身。

Figure 7 解读: 同时使用 的 similarity weight 优于只用单个 connector error 或不加权,说明两个锚点都能过滤不可靠 leap。

Figure 8 解读: 在完整 timestep range 随机选择 优于只覆盖局部区间,支持“任意生成步都应可被更新”的设计。

Figure 9 解读: 随机选择 比固定距离更好且更简单,避免只优化固定跨度的局部 dynamics。

Figure 10 解读: 只通过 nested gradient 更新第一跳时, 增大 gradient norm 并伤害性能; 在性能和梯度规模之间取得更好折中,验证 discounting 的必要性。

Figure 11 解读: 该图跟踪 GenEval score 随 fine-tuning 的提升,LeapAlign 相对 ReFL、DRaFT-LV、DRTune 提升更快且最终更高。

Figure 12 解读: 使用 HPSv3 reward 微调后的 Flux 生成样例,展示方法可从 CLIP-style reward 泛化到 VLM-based reward。

Figure 13 解读: 更多 HPSv3 reward 下的定性结果,强调 LeapAlign 不是只对 HPSv2.1 有效。

伪代码(基于开源参考实现)

Code reference: unofficial reference implementation main @ 0f706528 (2026-04-17) — pseudocode and mapping based on this commit; official RockeyCoss/leapalign main @ 3612c121 (2026-04-20) currently contains project page/assets, not training code.

import torch
 
 
@torch.no_grad()
def euler_rollout(v_theta, x1, num_steps=25):
    traj = [(1.0, x1)]
    x = x1
    for s in range(num_steps):
        t_cur = 1.0 - s / num_steps
        t_next = 1.0 - (s + 1) / num_steps
        t = torch.full((x.size(0),), t_cur, device=x.device, dtype=x.dtype)
        v = v_theta(x, t)
        x = x + (t_next - t_cur) * v
        traj.append((t_next, x))
    return traj
import torch
 
 
def build_leap_trajectory(v_theta, x_k, x_j, x_0, k, j, alpha=0.5):
    x_k = x_k.detach()
    x_j_true = x_j.detach()
    x_0_true = x_0.detach()
    t_k = torch.full((x_k.size(0),), k, device=x_k.device, dtype=x_k.dtype)
    t_j = torch.full((x_j_true.size(0),), j, device=x_j_true.device, dtype=x_j_true.dtype)
 
    v_k = v_theta(x_k, t_k)
    x_hat_j_given_k = x_k - (k - j) * v_k
    x_j_connected = x_hat_j_given_k + (x_j_true - x_hat_j_given_k).detach()
 
    x_j_for_v = alpha * x_j_connected + (1.0 - alpha) * x_j_connected.detach()
    v_j = v_theta(x_j_for_v, t_j)
    x_hat_0_given_j = x_j_connected - j * v_j
    x_0_connected = x_hat_0_given_j + (x_0_true - x_hat_0_given_j).detach()
    return x_hat_j_given_k, x_hat_0_given_j, x_j_connected, x_0_connected
import torch
import torch.nn.functional as F
 
 
def trajectory_similarity_weight(x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, tau=0.1):
    with torch.no_grad():
        d_j = (x_j - x_hat_j_given_k).abs().flatten(1).mean(dim=1)
        d_0 = (x_0 - x_hat_0_given_j).abs().flatten(1).mean(dim=1)
        tau_t = torch.as_tensor(tau, device=d_j.device, dtype=d_j.dtype)
        return 1.0 / (torch.maximum(d_j, tau_t) + torch.maximum(d_0, tau_t))
 
 
def leap_align_loss(x_0_connected, x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, reward_fn, lam=0.55, tau=0.1):
    reward = reward_fn(x_0_connected)
    raw = F.relu(lam - reward)
    w_sim = trajectory_similarity_weight(x_j, x_0, x_hat_j_given_k, x_hat_0_given_j, tau)
    return (w_sim * raw).mean()
import torch
 
 
def leap_align_step(v_theta, x1, reward_fn, num_ode_steps=25, alpha=0.5, lam=0.55, tau=0.1):
    traj = euler_rollout(v_theta, x1, num_steps=num_ode_steps)
    i_k = torch.randint(1, num_ode_steps + 1, (1,)).item()
    i_j = torch.randint(0, i_k, (1,)).item()
    k, j = i_k / num_ode_steps, i_j / num_ode_steps
    x_k = traj[num_ode_steps - i_k][1]
    x_j = traj[num_ode_steps - i_j][1]
    x_0 = traj[-1][1]
    xhj, xh0, xjc, x0c = build_leap_trajectory(v_theta, x_k, x_j, x_0, k, j, alpha)
    return leap_align_loss(x0c, x_j, x_0, xhj, xh0, reward_fn, lam, tau)
 
 
def train_update(v_theta, x1, reward_fn, optimizer):
    loss = leap_align_step(v_theta, x1, reward_fn)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(v_theta.parameters(), 1.0)
    optimizer.step()
    return loss.detach()

Code-to-paper mapping table:

Paper ConceptSource FileKey Class/Function
Full ODE rollout leap_align.pyeuler_rollout
Random sampling over rollout gridleap_align.pysample_leap_timesteps
Eq. (4)–(7) two-step leap + latent connectorsleap_align.pybuild_leap_trajectory, LeapTrajectory
Eq. (9) gradient discountingleap_align.pybuild_leap_trajectory (x_j_for_v)
Eq. (12) trajectory similarityleap_align.pytrajectory_similarity_weight
Eq. (11), (13) hinge reward lossleap_align.pyleap_align_loss, LeapAlignLoss
End-to-end loss constructionleap_align.pyleap_align_step
Optimizer backward / grad clip / AdamW / EMAtrain_leap_align.pymain training loop
Toy runnable training / optimizer / EMAtrain_leap_align.pymain, pretrain_flow, pretrain_reward

4. Experimental Setup(实验设置)

  • Datasets: General preference alignment 使用 HPDv2 抽样 50,000 prompts 训练,也使用 MJHQ-30k prompts;HPDv2 测试为 400 prompts × 4 images = 1,600 images,MJHQ-30k 随机留 500 prompts 测试。Compositional alignment 使用 GenEval official scripts 生成的 50,000 prompts 训练,类别比例 Position:Counting:Attribute Binding:Colors:Two Objects:Single Object = 7:5:3:1:1:0;测试按 GenEval 553 prompts × 4 images
  • Baselines: Policy-gradient:DanceGRPO、MixGRPO;Direct-gradient:ReFL、DRaFT-LV、DRTune。论文说明 DRaFT-LV / DRTune 无官方实现,由作者按伪代码复现;ReFL 从官方实现适配 Flux。
  • Metrics: HPSv2.1、HPSv3、PickScore、ImageReward 衡量 human preference;UnifiedReward-Alignment 衡量 image-text alignment;UnifiedReward-IQ 衡量 image quality;GenEval Overall 与六个子类(Single Object、Two Objects、Counting、Colors、Position、Attribute Binding)衡量组合生成正确性。
  • Training config: Base model 为 FLUX.1-dev;默认 reward HPSv2.1;AdamW lr=1e-5,batch size 64,weight decay 1e-4,EMA decay 0.995;训练 300 iterations on 16 GPUs。训练 rollout:720×720, 25 steps, CFG 3.5;评估:720×720, 50 steps, CFG 3.5。附录补充:HPSv2.1 时 ;PickScore / HPSv3 时 ;HPSv3 使用 lr=8e-6 对 HPSv2.1 / PickScore / HPSv3 分别为 0.55 / 0.4 / 13.5

5. Experimental Results(实验结果)

方法能力对比(Table 1)

MethodEarly StepsNested GradientLeap TrajectoryMulti-Step
ReFL
DRaFT-LV
DRTune
LeapAlign

Flux 主表(Table 2)

MethodHPSv2.1 ↑HPSv3 ↑PickScore ↑UR-Align ↑UR-IQ ↑ImageReward ↑GenEval Overall ↑
Flux0.307813.502022.79023.45143.57081.04550.6535
DanceGRPO0.345114.833623.11863.46603.61991.23470.6775
MixGRPO†0.369214.753023.51843.43933.62411.61550.7232
ReFL‡0.385215.512723.62993.47863.68701.34680.7011
DRaFT-LV*0.385915.369923.64373.48683.68871.33840.7024
DRTune*0.388215.560623.51853.47933.66791.35620.7101
LeapAlign0.409215.767823.71373.49843.72441.51040.7420

GenEval 子类(Table 2)

MethodSingle Obj.Two Obj.CountColorPositionAttrB
Flux99.3886.6266.8874.4719.5045.25
MixGRPO†99.6993.6980.0080.0524.2556.25
DRTune*99.3893.6973.1276.8627.5055.50
LeapAlign99.3896.4672.5080.5930.2566.00

不同 reward / prompt set(Table 3)

MethodHPSv2.1 ↑PickScore ↑HPSv3 ↑
Flux0.307822.790210.7624
ReFL‡0.385225.237311.7642
DRaFT-LV*0.385924.959611.2701
DRTune*0.388225.102112.0023
LeapAlign0.409225.758912.5855

SD3.5-M 泛化(Appendix Table 4)

MethodHPSv2.1 ↑HPSv3 ↑PickScore ↑UR-Align ↑UR-IQ ↑ImageReward ↑
SD3.5-M0.296712.284622.51893.44363.55651.0614
ReFL‡0.383315.248823.50153.48103.68721.4239
DRaFT-LV*0.350614.602223.07473.46913.65191.3008
DRTune*0.382815.254123.47113.47383.66671.4320
LeapAlign0.391515.578023.61803.48963.71821.4736

阈值消融(Appendix Table 5)

HPSv2.1HPSv3PickScoreImageReward
0.350.386015.363523.47351.3510
0.550.409215.767823.71371.5104
0.750.409115.727423.70611.4844
0.950.402315.725423.50821.3888

关键发现、局限与结论

  • 主结果: LeapAlign 在 Flux 上取得最高 HPSv2.1、HPSv3、PickScore、UnifiedReward-Alignment、UnifiedReward-IQ 与 GenEval Overall;虽只用 HPSv2.1 reward 训练,仍超过用三 reward 的 MixGRPO 的 HPSv2.1 / PickScore,并在 ImageReward 上保持竞争力。
  • 消融: 、两步 leap、真实 reward input、 similarity weighting、全区间随机 timestep selection 均是关键设计; 在 HPSv2.1 设置下最好。
  • 局限: 需要可微 reward model;不可微 reward 需要未来借助 differentiable value model 等扩展。作者还指出未来会把 LeapAlign 实现并改进到 video generation。