Unified Latents (UL): How to train your latents

Paper: arXiv:2602.17270 Code: 代码搜索未找到开源实现(已搜索论文标题、作者、2602.17270、GitHub、Papers with Code / Hugging Face paper 页面)

1. Motivation (研究动机)

潜在扩散模型的核心收益是把高分辨率图像/视频压缩到低维 latent 后再建模,但现有 latent 的训练方式往往把两个目标混在一起:一方面 latent 要保留足够信息以便高保真重建,另一方面 latent 分布又必须足够简单,扩散 prior 才容易学习。Stable Diffusion 类 VAE 通常用 GAN decoder、通道瓶颈和较小 KL penalty 调 latent 信息量;这种做法有效但不够原则化,很难直接回答「latent 里到底有多少 bit」以及「这个 bitrate 是否适合给定 base model」。

当前方法的具体瓶颈有三类。第一,GAN/VAE autoencoder 的 reconstruction quality、latent Gaussianity 和 downstream generation quality 往往要靠经验调参共同折中,缺少可解释的 bitrate 控制。第二,若 decoder 太强,VAE 容易出现 posterior collapse,latent 不被充分使用;若 latent 太强,base diffusion model 又会被复杂 latent 分布拖慢。第三,很多比较把 autoencoder training data、encoder pretraining、GAN decoder 与 diffusion decoder 混在一起,导致很难判断 latent 本身是否更适合生成。

本文要解决的具体问题是:如何从头训练一个对生成建模友好的连续 latent 表示,使 encoder、prior 和 decoder 的目标在同一个 diffusion/ELBO 框架下对齐,并让 latent bitrate 可解释、可控。作者的切入点不是换一个更强的 encoder,而是把 VAE 的 prior 正则化项也写成 diffusion prior 的 loss,再让 decoder 也用 diffusion model 重建图像。

这个问题值得做,因为 latent 的信息量直接决定 training compute 与生成质量的 trade-off:latent 太强,重建好但 base model 学不动;latent 太弱,base model 容易建模但 decoder 需要补太多细节。UL 把这个 trade-off 显式化后,可以在 ImageNet-512 上达到 1.4 FID,并在 Kinetics-600 上达到 1.3 FVD,同时用比 Stable Diffusion latents 更少的 base-model training FLOPs。

从应用角度看,若这种 latent training 方式成立,后续的图像/视频 foundation model 可以把「给定算力下应该训练多高 bitrate 的 latent」当作一个可研究的 scaling 问题,而不是每次重新经验搜索 autoencoder 架构与 KL 权重。

2. Idea (核心思想)

核心洞察:不要把 latent 正则化当作一个手工 KL penalty 或通道瓶颈,而是把 encoder 输出的最小噪声精度与 latent diffusion prior 的最大精度绑定;这样 VAE 的 KL 项可以化成一个 diffusion prior 的加权 MSE,并给 latent bitrate 一个紧的上界。decoder 则反过来允许使用重加权 diffusion loss,让视觉上不重要的高频细节更倾向于由 decoder 建模。

可以把 UL 理解成一个「双 diffusion 的 VAE」:prior diffusion model 不是只负责生成 latent,而是在训练中充当信息论正则器;decoder diffusion model 不是普通像素 MSE decoder,而是一个能建模条件图像分布的强 decoder。二者的张力由 loss factor 与 sigmoid bias 控制,因此 bitrate 不再是隐含在 channel count 或 KL coefficient 里的副作用。

关键创新可以概括为三点:1) deterministic encoder 先输出 ,再在 加固定 Gaussian noise 形成 ;2) prior diffusion model 必须用未重加权的 ELBO(),避免 encoder 把信息藏到被 discount 的 noise level;3) decoder diffusion model 使用 sigmoid/loss-factor weighting,通过 和 bias 控制 latent bitrate。

与 LDM / Stable Diffusion VAE 的根本区别是,后者主要靠 autoencoder 结构和小 KL heuristics 控制 latent,而 UL 直接让 diffusion prior 参与 latent 正则化。与 LSGM 这类 diffusion-prior VAE 相比,UL 不学习任意 encoder posterior,而是用 deterministic mean + fixed noise,避免单独的 encoder entropy term 带来的训练不稳定。

与 RAE / DINO / SigLIP 这类 representation-as-latent 路线相比,UL 的目标不是把语义 encoder 直接拿来生成,而是从一开始就以 downstream diffusion modeling efficiency 为目标学习 latent。因此它更关注 bitrate、rFID/gFID 和 training compute 的共同最优,而不是只追求重建或语义特征质量。

3. Method (方法)

3.1 总体框架:encoder + diffusion prior + diffusion decoder

Figure 1 解读:左侧 encoder 把图像 压缩成 clean latent 。上支路把 加噪成 ,训练 diffusion prior 预测 latent;下支路把 加固定噪声得到 ,并作为 diffusion decoder 的条件来重建图像。这个图说明 UL 不是先训练一个独立 VAE 再训练 diffusion,而是在同一阶段联合训练 encoder、prior、decoder。

VAE 视角下,目标从标准 ELBO 出发:

UL 的设计是让 都由 diffusion model 学习:prior 管 latent 分布,decoder 管从 latent 到图像的条件生成。

3.2 Encoding & prior:把 encoder 噪声和 prior 精度绑定

Figure 2 解读:图中 先变成 ,再得到稍带噪声的 ;diffusion prior 建模从纯噪声 的路径,decoder 则从 重建图像。关键是 不是任意 posterior sample,而是 encoder mean 加固定噪声,因此 prior 的最小噪声层级就是 latent 的最大信息精度。

具体地,encoder 输出 deterministic latent:

论文使用最终 log-SNR ,对应 variance-preserving schedule 下:

此时 VAE 的 latent KL 项被上界为 diffusion prior loss:

由于这是 prior 对 latent bitrate 的正则化,论文强调 prior 侧使用未重加权 ELBO,即 。直觉上,如果 prior 也像图像 diffusion 一样 discount 某些噪声层,encoder 就会把信息塞到便宜的层级里,bitrate 上界会失真。

3.3 Decoder:用 diffusion decoder 接管重建项

Figure 3 解读:decoder 的 -MSE 权重写成 。图中权重大于 1 的区域更鼓励 latent/decoder 分担信息,小于 1 的区域被 discount,意味着一些视觉高频细节更便宜地交给 decoder 建模。 越大,latent 通常越 informative,重建更好但 base model 更难学。

decoder 在图像空间做 diffusion:

并以 为条件预测

直觉上,prior loss 是「让 latent 易建模」的压力,decoder loss 是「让图像可重建」的压力;二者通过 loss factor / sigmoid bias 调整平衡。论文报告多数实验中只需要 的小幅放大。

3.4 两阶段训练与采样

Stage 1 联合训练 encoder、prior 和 decoder;Stage 2 冻结 encoder/decoder,再训练一个更大的 base diffusion model 来生成 latent。这样做的原因是 Stage 1 prior 必须用 ELBO weighting 才能正则化 latent,但这种 prior 直接采样质量不够好;Stage 2 的 base model 可以像标准 LDM 一样用 sigmoid weighting,并且由于只需要冻结的 encoder,模型和 batch size 都可以更大。

论文级伪代码如下;注意:代码搜索未找到开源实现,因此这是按论文 Algorithm 1/2 转写的 PyTorch 风格伪代码,不是某个仓库的真实源码。

import torch
import torch.nn.functional as F
 
 
def vp_noisy(x, alpha, sigma, eps):
    return alpha * x + sigma * eps
 
 
def unified_latents_train_step(encoder, prior, decoder, batch, opt, sched_z, sched_x,
                               loss_factor=1.5, decoder_bias=0.0):
    x = batch["image"]
    z_clean = encoder(x)
 
    # Prior branch: unweighted latent ELBO / bitrate regularizer.
    t_z = torch.rand(x.shape[0], device=x.device)
    eps_z = torch.randn_like(z_clean)
    alpha_z, sigma_z, logsnr_z, dlogsnr_z_dt = sched_z(t_z)
    z_t = vp_noisy(z_clean, alpha_z, sigma_z, eps_z)
    z_hat = prior(z_t, t_z)
    prior_weight = -dlogsnr_z_dt * torch.exp(logsnr_z) / 2.0
    prior_loss = (prior_weight * (z_clean - z_hat).pow(2).flatten(1).mean(1)).mean()
 
    # Decoder branch: fixed-noise z0 conditions image-space diffusion decoder.
    t_x = torch.rand(x.shape[0], device=x.device)
    eps_x = torch.randn_like(x)
    eps_z0 = torch.randn_like(z_clean)
    alpha_0, sigma_0, _, _ = sched_z(torch.zeros_like(t_z))  # lambda_z(0)=5 in paper
    z_0 = vp_noisy(z_clean, alpha_0, sigma_0, eps_z0)
    alpha_x, sigma_x, logsnr_x, dlogsnr_x_dt = sched_x(t_x)
    x_t = vp_noisy(x, alpha_x, sigma_x, eps_x)
    x_hat = decoder(x_t, t_x, z_0)
    decoder_weight = loss_factor * torch.sigmoid(decoder_bias - logsnr_x)
    decoder_loss = (decoder_weight * (x - x_hat).pow(2).flatten(1).mean(1)).mean()
 
    loss = prior_loss + decoder_loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    return {"loss": loss, "prior_loss": prior_loss, "decoder_loss": decoder_loss}
 
 
def sample_unified_latents(base_model, decoder, latent_shape, image_shape, sched_z, sched_x):
    z_1 = torch.randn(latent_shape)
    z_0 = base_model.sample_reverse_diffusion(z_1, sched_z)
    x_1 = torch.randn(image_shape)
    x = decoder.sample_reverse_diffusion(x_1, condition=z_0, schedule=sched_x)
    return x

Code mapping status: 代码搜索未找到开源实现;不能给出真实 Source File / Class / Function 对应关系。下表仅记录论文组件级映射,避免伪造文件路径。

Paper ConceptSource FileKey Class/Function
deterministic encoder + fixed-noise 公开代码不可用论文 Algorithm 1;Section 3.1
unweighted latent diffusion prior loss公开代码不可用论文 Eq. prior KL bound;Algorithm 1 prior branch
conditional diffusion decoder loss公开代码不可用论文 decoder reconstruction bound;Algorithm 1 decoder branch
Stage-2 base latent diffusion model公开代码不可用Section 3.3 Base model;Algorithm 2 sampling
loss-factor / sigmoid-bias bitrate control公开代码不可用Figure 3;Table 2 bitrate sweep

4. Experimental Setup (实验设置)

数据集与任务

  • ImageNet-512:主图像生成与 latent bitrate/shape ablation。论文未详细说明训练样本数;使用 ImageNet-512 是为了和 prior image-generation work 的 training compute / FID 曲线对齐。
  • Internal Text-To-Image datasets:用于大规模 TTI scaling,论文未公开数据集名称和样本数;AutoEncoder loss factor sweep 为 1.25–1.7,每个 AE 训练 100/300/970 GFlops base model,并用 30k unguided samples 计算 CLIP 与 FID。
  • Kinetics-600:视频生成实验;使用 frames、 videos,latent 下采样为 ;按 Video Diffusion 设置 condition on 5 frames, generate 11 frames。论文未详细说明训练样本数。

Baselines

图像侧对比包括 Stable Diffusion latents、Pixel diffusion/no latents、UNet(SD)、EDM2-S/XXL、DiT-XL/2、SiD2、RAE;视频侧对比 MAGVIT-v2、W.A.L.T.、Video Diffusion、RIN。内部消融包括 prior model stop-gradient / KL 替代、去掉 noisy latents、不同 AE training data、learned variance、MSE decoder loss、normal prior、latent channel count 与 spatial downsampling。

指标

  • gFID:base model 采样图像的 FID;论文用它衡量 generation quality。
  • rFID:autoencoder reconstruction 的 FID;使用相同 dataset samples 做 reconstruction 与 FID reference。
  • PSNR:重建图像与原图像的像素级保真度。
  • CLIP score:text-to-image 的文本对齐指标。
  • FVD:视频生成质量指标;Kinetics-600 上报告 training cost vs FVD。
  • Training cost:以 zettaFLOPs per model 度量;Figure 4 明确不包含 autoencoder training cost,并假设一次训练 iteration 是一次模型 eval 的 3 倍开销。

模型与训练配置

论文给出的结构配置如下;学习率、batch size、optimizer、GPU 类型/数量和总训练步数论文未详细说明。

  • Encoder / Decoder input:encoder 和 decoder 都使用 patching 以节省 compute。
  • Encoder:ResNet,[128, 256, 512, 512] channels;downsampling stage 使用 2 residual blocks,final stage 使用 3 blocks。
  • Prior model:single-level ViT,8 blocks,1024 channels;Stage 1 prior loss 使用 unweighted ELBO。
  • Base model:2-stage ViT,[512, 1024] channels,[6, 16] blocks;两阶段 dropout 都为 0.1;Stage 2 在 frozen encoder/decoder latents 上训练。
  • Decoder:UVit;conv down/up stages channels 为 [128, 256, 512];middle transformer 为 8 blocks、1024 channels;dropout=0.1。
  • Latent precision,对应 ;loss factor 多数实验为 1.3–1.7。

5. Experimental Results (实验结果)

5.1 ImageNet-512 training efficiency

Figure 4 解读:绿色点是 UL,紫色点是 Stable-Diffusion-style baselines,蓝色点是其他 previous methods。UL small/medium 在显著更低的 base-model training cost 下达到约 1.4–1.6 区间 FID;论文摘要报告 ImageNet-512 上达到 1.4 FID。注意该图不计入 autoencoder training cost,所以它主要说明「给定 latent 后训练 base model」的效率。

5.2 Text-to-image quality and alignment

Figure 5 解读:这是 UL text-to-image 模型的 hand-picked samples,展示 diffusion decoder + UL latent 并没有明显破坏视觉细节,青蛙毛衣、巴黎铁塔前小猫、冲浪熊猫、火车等 prompt 都能被生成。

Figure 6 解读:左图是不同 base model size 下 loss factor 与 gFID@30K 的关系,右图是 loss factor 与 CLIP 的关系。小模型更偏好低 bitrate / 低 loss factor;大模型对 bitrate 更不敏感,说明 base model capacity 越大,越能吃下更 informative 的 latent。

LatentsgFID@30KCLIP
UL (LF=1.5)4.127.1
Pixel (no latents)5.027.0
StableDiffusion6.827.0

Table 1 说明,在同一 text-to-image 评价中,UL (LF=1.5) 的 gFID@30K 明显优于 pixel diffusion 和 Stable Diffusion latents,CLIP 也略高。

5.3 Bitrate / reconstruction / generation trade-off

Figure 7 解读:loss factor 从 1.3 增加到 2.1 时,公交车上的小字和边缘细节逐步变清晰,说明 higher-bitrate latents 能提升 reconstruction fidelity;但这不必然提升 generation,因为 base model 会更难建模。

LFbits/pixelrFID@50kPSNRgFID smallgFID medium
1.30.0350.7925.71.421.37
1.50.0590.4727.61.541.31
1.70.0830.3628.91.771.38
1.90.1010.3129.62.021.45
2.10.1160.2730.12.381.58

Table 2 的关键结论是:更高 LF 持续提升 rFID/PSNR,但 small base model 的最佳 gFID 在 LF=1.3,而 medium base model 的最佳 gFID 在 LF=1.5。这正是 UL 想显式控制的 reconstruction-modeling trade-off。

Figure 8 解读:三联图分别看 gFID、rFID、PSNR 随 latent bits/pixel 的变化;对于 small model,gFID 在中低 bitrate 有最优点,而 rFID/PSNR 随 bitrate 单调改善或近似改善。它说明「重建最好」不是「生成最好」。

5.4 Latent shape and AE ablations

# chanrFIDgFID@50K
47.19-
81.53-
160.541.76
320.421.60
640.481.77

Table 3 显示,在固定 spatial downsampling 时,16/32/64 channels 的 generation 差异不大;4/8 channels 因信息不足导致 reconstruction 明显变差。

Latent shaperFID@50KgFID@50K
0.402.12
0.411.63
1.411.74

Table 4 显示, 在 rFID 接近 的同时更容易被 decoder/base model 建模,gFID 更低。

PriorReconstruction lossLatent bpd with priorLatent bpd with base modelrFID@50KgFID@50K
DiffusionDiffusion0.0790.0790.861.4
DiffusionMSE0.0720.0721.12.4
NormalDiffusion0.390.260.832.5

Table 5 说明 diffusion prior + diffusion reconstruction loss 是最稳的组合;MSE reconstruction 或 normal prior 虽然可能得到可接受 rFID,但 gFID 明显恶化。

Ablationbits/pixelrFID@50kPSNRgFID@50k
UL baseline (LF=1.5)0.0590.4727.61.54
A. prior model0.1211.81nan7.80
B. noisy latents0.00828.27nan-
C. ImageNet data0.0341.3724.71.63
D. learned variance0.0600.69nan1.81

Table 6 最能证明组件必要性:去掉 prior 对 encoder 的有效 regularization 后 gFID 退化到 7.80;去掉 noisy latents 后 bitrate estimate 失效,rFID 变成 28.27;learned variance 不如 fixed variance 稳定。

5.5 Video generation and additional samples

Figure 9 解读:在 Kinetics-600 上,UL 曲线位于 MAGVIT-v2、W.A.L.T.、Video Diffusion、RIN 等 prior work 下方,表示更低 FVD 或更少 training cost。论文明确报告 UL small 已达到 1.7 FVD,UL medium 达到 1.3 FVD,并称其为当时 SOTA。

Figure 10 解读:这是非 cherry-picked text-to-image generations(guidance=2),用于补充 Figure 5 的定性结果。它展示 UL 训练出的 latent 在多个 prompt 上能保持较稳定的视觉质量,但论文没有把这些样本作为定量指标来源。

5.6 限制与结论

作者指出三个主要限制。第一,较弱/低 bitrate latent 更容易建模,但这可能把一部分建模负担转移给 decoder;这会影响如何公平比较不同 autoencoder。第二,不同方法的 AE 训练数据差异很大:Stable Diffusion AE 用大规模 web dataset,而本文多数 ImageNet 实验只用 ImageNet;DINO / SigLIP / RAE 等 semi-supervised encoder 又引入外部数据。第三,diffusion decoder 的采样成本比 GAN decoder 高一个数量级;若没有 decoder distillation,UL 的端到端推理成本会高于标准 LDM。

总体结论:UL 的价值不只是某个 FID 数字,而是提供了一套用 diffusion prior 正则化 latent、用 diffusion decoder 重建图像、再用 Stage-2 base model 生成 latent 的统一训练框架。它把 latent bitrate 从经验调参变成可解释的超参数,并在图像/视频上显示了更好的 training-efficiency trade-off。