Unified Latents (UL): How to train your latents
Paper: arXiv:2602.17270 Code: 代码搜索未找到开源实现(已搜索论文标题、作者、
2602.17270、GitHub、Papers with Code / Hugging Face paper 页面)
1. Motivation (研究动机)
潜在扩散模型的核心收益是把高分辨率图像/视频压缩到低维 latent 后再建模,但现有 latent 的训练方式往往把两个目标混在一起:一方面 latent 要保留足够信息以便高保真重建,另一方面 latent 分布又必须足够简单,扩散 prior 才容易学习。Stable Diffusion 类 VAE 通常用 GAN decoder、通道瓶颈和较小 KL penalty 调 latent 信息量;这种做法有效但不够原则化,很难直接回答「latent 里到底有多少 bit」以及「这个 bitrate 是否适合给定 base model」。
当前方法的具体瓶颈有三类。第一,GAN/VAE autoencoder 的 reconstruction quality、latent Gaussianity 和 downstream generation quality 往往要靠经验调参共同折中,缺少可解释的 bitrate 控制。第二,若 decoder 太强,VAE 容易出现 posterior collapse,latent 不被充分使用;若 latent 太强,base diffusion model 又会被复杂 latent 分布拖慢。第三,很多比较把 autoencoder training data、encoder pretraining、GAN decoder 与 diffusion decoder 混在一起,导致很难判断 latent 本身是否更适合生成。
本文要解决的具体问题是:如何从头训练一个对生成建模友好的连续 latent 表示,使 encoder、prior 和 decoder 的目标在同一个 diffusion/ELBO 框架下对齐,并让 latent bitrate 可解释、可控。作者的切入点不是换一个更强的 encoder,而是把 VAE 的 prior 正则化项也写成 diffusion prior 的 loss,再让 decoder 也用 diffusion model 重建图像。
这个问题值得做,因为 latent 的信息量直接决定 training compute 与生成质量的 trade-off:latent 太强,重建好但 base model 学不动;latent 太弱,base model 容易建模但 decoder 需要补太多细节。UL 把这个 trade-off 显式化后,可以在 ImageNet-512 上达到 1.4 FID,并在 Kinetics-600 上达到 1.3 FVD,同时用比 Stable Diffusion latents 更少的 base-model training FLOPs。
从应用角度看,若这种 latent training 方式成立,后续的图像/视频 foundation model 可以把「给定算力下应该训练多高 bitrate 的 latent」当作一个可研究的 scaling 问题,而不是每次重新经验搜索 autoencoder 架构与 KL 权重。
2. Idea (核心思想)
核心洞察:不要把 latent 正则化当作一个手工 KL penalty 或通道瓶颈,而是把 encoder 输出的最小噪声精度与 latent diffusion prior 的最大精度绑定;这样 VAE 的 KL 项可以化成一个 diffusion prior 的加权 MSE,并给 latent bitrate 一个紧的上界。decoder 则反过来允许使用重加权 diffusion loss,让视觉上不重要的高频细节更倾向于由 decoder 建模。
可以把 UL 理解成一个「双 diffusion 的 VAE」:prior diffusion model 不是只负责生成 latent,而是在训练中充当信息论正则器;decoder diffusion model 不是普通像素 MSE decoder,而是一个能建模条件图像分布的强 decoder。二者的张力由 loss factor 与 sigmoid bias 控制,因此 bitrate 不再是隐含在 channel count 或 KL coefficient 里的副作用。
关键创新可以概括为三点:1) deterministic encoder 先输出 ,再在 加固定 Gaussian noise 形成 ;2) prior diffusion model 必须用未重加权的 ELBO(),避免 encoder 把信息藏到被 discount 的 noise level;3) decoder diffusion model 使用 sigmoid/loss-factor weighting,通过 和 bias 控制 latent bitrate。
与 LDM / Stable Diffusion VAE 的根本区别是,后者主要靠 autoencoder 结构和小 KL heuristics 控制 latent,而 UL 直接让 diffusion prior 参与 latent 正则化。与 LSGM 这类 diffusion-prior VAE 相比,UL 不学习任意 encoder posterior,而是用 deterministic mean + fixed noise,避免单独的 encoder entropy term 带来的训练不稳定。
与 RAE / DINO / SigLIP 这类 representation-as-latent 路线相比,UL 的目标不是把语义 encoder 直接拿来生成,而是从一开始就以 downstream diffusion modeling efficiency 为目标学习 latent。因此它更关注 bitrate、rFID/gFID 和 training compute 的共同最优,而不是只追求重建或语义特征质量。
3. Method (方法)
3.1 总体框架:encoder + diffusion prior + diffusion decoder
Figure 1 解读:左侧 encoder 把图像 压缩成 clean latent 。上支路把 加噪成 ,训练 diffusion prior 预测 latent;下支路把 在 加固定噪声得到 ,并作为 diffusion decoder 的条件来重建图像。这个图说明 UL 不是先训练一个独立 VAE 再训练 diffusion,而是在同一阶段联合训练 encoder、prior、decoder。
VAE 视角下,目标从标准 ELBO 出发:
UL 的设计是让 和 都由 diffusion model 学习:prior 管 latent 分布,decoder 管从 latent 到图像的条件生成。
3.2 Encoding & prior:把 encoder 噪声和 prior 精度绑定
Figure 2 解读:图中 先变成 ,再得到稍带噪声的 ;diffusion prior 建模从纯噪声 到 的路径,decoder 则从 重建图像。关键是 不是任意 posterior sample,而是 encoder mean 加固定噪声,因此 prior 的最小噪声层级就是 latent 的最大信息精度。
具体地,encoder 输出 deterministic latent:
论文使用最终 log-SNR ,对应 variance-preserving schedule 下:
此时 VAE 的 latent KL 项被上界为 diffusion prior loss:
由于这是 prior 对 latent bitrate 的正则化,论文强调 prior 侧使用未重加权 ELBO,即 。直觉上,如果 prior 也像图像 diffusion 一样 discount 某些噪声层,encoder 就会把信息塞到便宜的层级里,bitrate 上界会失真。
3.3 Decoder:用 diffusion decoder 接管重建项
Figure 3 解读:decoder 的 -MSE 权重写成 。图中权重大于 1 的区域更鼓励 latent/decoder 分担信息,小于 1 的区域被 discount,意味着一些视觉高频细节更便宜地交给 decoder 建模。 越大,latent 通常越 informative,重建更好但 base model 更难学。
decoder 在图像空间做 diffusion:
并以 为条件预测 :
直觉上,prior loss 是「让 latent 易建模」的压力,decoder loss 是「让图像可重建」的压力;二者通过 loss factor / sigmoid bias 调整平衡。论文报告多数实验中只需要 到 的小幅放大。
3.4 两阶段训练与采样
Stage 1 联合训练 encoder、prior 和 decoder;Stage 2 冻结 encoder/decoder,再训练一个更大的 base diffusion model 来生成 latent。这样做的原因是 Stage 1 prior 必须用 ELBO weighting 才能正则化 latent,但这种 prior 直接采样质量不够好;Stage 2 的 base model 可以像标准 LDM 一样用 sigmoid weighting,并且由于只需要冻结的 encoder,模型和 batch size 都可以更大。
论文级伪代码如下;注意:代码搜索未找到开源实现,因此这是按论文 Algorithm 1/2 转写的 PyTorch 风格伪代码,不是某个仓库的真实源码。
import torch
import torch.nn.functional as F
def vp_noisy(x, alpha, sigma, eps):
return alpha * x + sigma * eps
def unified_latents_train_step(encoder, prior, decoder, batch, opt, sched_z, sched_x,
loss_factor=1.5, decoder_bias=0.0):
x = batch["image"]
z_clean = encoder(x)
# Prior branch: unweighted latent ELBO / bitrate regularizer.
t_z = torch.rand(x.shape[0], device=x.device)
eps_z = torch.randn_like(z_clean)
alpha_z, sigma_z, logsnr_z, dlogsnr_z_dt = sched_z(t_z)
z_t = vp_noisy(z_clean, alpha_z, sigma_z, eps_z)
z_hat = prior(z_t, t_z)
prior_weight = -dlogsnr_z_dt * torch.exp(logsnr_z) / 2.0
prior_loss = (prior_weight * (z_clean - z_hat).pow(2).flatten(1).mean(1)).mean()
# Decoder branch: fixed-noise z0 conditions image-space diffusion decoder.
t_x = torch.rand(x.shape[0], device=x.device)
eps_x = torch.randn_like(x)
eps_z0 = torch.randn_like(z_clean)
alpha_0, sigma_0, _, _ = sched_z(torch.zeros_like(t_z)) # lambda_z(0)=5 in paper
z_0 = vp_noisy(z_clean, alpha_0, sigma_0, eps_z0)
alpha_x, sigma_x, logsnr_x, dlogsnr_x_dt = sched_x(t_x)
x_t = vp_noisy(x, alpha_x, sigma_x, eps_x)
x_hat = decoder(x_t, t_x, z_0)
decoder_weight = loss_factor * torch.sigmoid(decoder_bias - logsnr_x)
decoder_loss = (decoder_weight * (x - x_hat).pow(2).flatten(1).mean(1)).mean()
loss = prior_loss + decoder_loss
opt.zero_grad()
loss.backward()
opt.step()
return {"loss": loss, "prior_loss": prior_loss, "decoder_loss": decoder_loss}
def sample_unified_latents(base_model, decoder, latent_shape, image_shape, sched_z, sched_x):
z_1 = torch.randn(latent_shape)
z_0 = base_model.sample_reverse_diffusion(z_1, sched_z)
x_1 = torch.randn(image_shape)
x = decoder.sample_reverse_diffusion(x_1, condition=z_0, schedule=sched_x)
return x3.5 Code-to-paper mapping / code search
Code mapping status: 代码搜索未找到开源实现;不能给出真实
Source File/Class/Function对应关系。下表仅记录论文组件级映射,避免伪造文件路径。
| Paper Concept | Source File | Key Class/Function |
|---|---|---|
| deterministic encoder + fixed-noise | 公开代码不可用 | 论文 Algorithm 1;Section 3.1 |
| unweighted latent diffusion prior loss | 公开代码不可用 | 论文 Eq. prior KL bound;Algorithm 1 prior branch |
| conditional diffusion decoder loss | 公开代码不可用 | 论文 decoder reconstruction bound;Algorithm 1 decoder branch |
| Stage-2 base latent diffusion model | 公开代码不可用 | Section 3.3 Base model;Algorithm 2 sampling |
| loss-factor / sigmoid-bias bitrate control | 公开代码不可用 | Figure 3;Table 2 bitrate sweep |
4. Experimental Setup (实验设置)
数据集与任务
- ImageNet-512:主图像生成与 latent bitrate/shape ablation。论文未详细说明训练样本数;使用 ImageNet-512 是为了和 prior image-generation work 的 training compute / FID 曲线对齐。
- Internal Text-To-Image datasets:用于大规模 TTI scaling,论文未公开数据集名称和样本数;AutoEncoder loss factor sweep 为 1.25–1.7,每个 AE 训练 100/300/970 GFlops base model,并用 30k unguided samples 计算 CLIP 与 FID。
- Kinetics-600:视频生成实验;使用 frames、 videos,latent 下采样为 ;按 Video Diffusion 设置 condition on 5 frames, generate 11 frames。论文未详细说明训练样本数。
Baselines
图像侧对比包括 Stable Diffusion latents、Pixel diffusion/no latents、UNet(SD)、EDM2-S/XXL、DiT-XL/2、SiD2、RAE;视频侧对比 MAGVIT-v2、W.A.L.T.、Video Diffusion、RIN。内部消融包括 prior model stop-gradient / KL 替代、去掉 noisy latents、不同 AE training data、learned variance、MSE decoder loss、normal prior、latent channel count 与 spatial downsampling。
指标
- gFID:base model 采样图像的 FID;论文用它衡量 generation quality。
- rFID:autoencoder reconstruction 的 FID;使用相同 dataset samples 做 reconstruction 与 FID reference。
- PSNR:重建图像与原图像的像素级保真度。
- CLIP score:text-to-image 的文本对齐指标。
- FVD:视频生成质量指标;Kinetics-600 上报告 training cost vs FVD。
- Training cost:以 zettaFLOPs per model 度量;Figure 4 明确不包含 autoencoder training cost,并假设一次训练 iteration 是一次模型 eval 的 3 倍开销。
模型与训练配置
论文给出的结构配置如下;学习率、batch size、optimizer、GPU 类型/数量和总训练步数论文未详细说明。
- Encoder / Decoder input:encoder 和 decoder 都使用 patching 以节省 compute。
- Encoder:ResNet,[128, 256, 512, 512] channels;downsampling stage 使用 2 residual blocks,final stage 使用 3 blocks。
- Prior model:single-level ViT,8 blocks,1024 channels;Stage 1 prior loss 使用 unweighted ELBO。
- Base model:2-stage ViT,[512, 1024] channels,[6, 16] blocks;两阶段 dropout 都为 0.1;Stage 2 在 frozen encoder/decoder latents 上训练。
- Decoder:UVit;conv down/up stages channels 为 [128, 256, 512];middle transformer 为 8 blocks、1024 channels;dropout=0.1。
- Latent precision:,对应 ;loss factor 多数实验为 1.3–1.7。
5. Experimental Results (实验结果)
5.1 ImageNet-512 training efficiency
---How-to-train-your-latents/fig4_imagenet_compute.png)
Figure 4 解读:绿色点是 UL,紫色点是 Stable-Diffusion-style baselines,蓝色点是其他 previous methods。UL small/medium 在显著更低的 base-model training cost 下达到约 1.4–1.6 区间 FID;论文摘要报告 ImageNet-512 上达到 1.4 FID。注意该图不计入 autoencoder training cost,所以它主要说明「给定 latent 后训练 base model」的效率。
5.2 Text-to-image quality and alignment
Figure 5 解读:这是 UL text-to-image 模型的 hand-picked samples,展示 diffusion decoder + UL latent 并没有明显破坏视觉细节,青蛙毛衣、巴黎铁塔前小猫、冲浪熊猫、火车等 prompt 都能被生成。
Figure 6 解读:左图是不同 base model size 下 loss factor 与 gFID@30K 的关系,右图是 loss factor 与 CLIP 的关系。小模型更偏好低 bitrate / 低 loss factor;大模型对 bitrate 更不敏感,说明 base model capacity 越大,越能吃下更 informative 的 latent。
| Latents | gFID@30K | CLIP |
|---|---|---|
| UL (LF=1.5) | 4.1 | 27.1 |
| Pixel (no latents) | 5.0 | 27.0 |
| StableDiffusion | 6.8 | 27.0 |
Table 1 说明,在同一 text-to-image 评价中,UL (LF=1.5) 的 gFID@30K 明显优于 pixel diffusion 和 Stable Diffusion latents,CLIP 也略高。
5.3 Bitrate / reconstruction / generation trade-off
---How-to-train-your-latents/lf_recon_sweep.png)
Figure 7 解读:loss factor 从 1.3 增加到 2.1 时,公交车上的小字和边缘细节逐步变清晰,说明 higher-bitrate latents 能提升 reconstruction fidelity;但这不必然提升 generation,因为 base model 会更难建模。
| LF | bits/pixel | rFID@50k | PSNR | gFID small | gFID medium |
|---|---|---|---|---|---|
| 1.3 | 0.035 | 0.79 | 25.7 | 1.42 | 1.37 |
| 1.5 | 0.059 | 0.47 | 27.6 | 1.54 | 1.31 |
| 1.7 | 0.083 | 0.36 | 28.9 | 1.77 | 1.38 |
| 1.9 | 0.101 | 0.31 | 29.6 | 2.02 | 1.45 |
| 2.1 | 0.116 | 0.27 | 30.1 | 2.38 | 1.58 |
Table 2 的关键结论是:更高 LF 持续提升 rFID/PSNR,但 small base model 的最佳 gFID 在 LF=1.3,而 medium base model 的最佳 gFID 在 LF=1.5。这正是 UL 想显式控制的 reconstruction-modeling trade-off。
Figure 8 解读:三联图分别看 gFID、rFID、PSNR 随 latent bits/pixel 的变化;对于 small model,gFID 在中低 bitrate 有最优点,而 rFID/PSNR 随 bitrate 单调改善或近似改善。它说明「重建最好」不是「生成最好」。
5.4 Latent shape and AE ablations
| # chan | rFID | gFID@50K |
|---|---|---|
| 4 | 7.19 | - |
| 8 | 1.53 | - |
| 16 | 0.54 | 1.76 |
| 32 | 0.42 | 1.60 |
| 64 | 0.48 | 1.77 |
Table 3 显示,在固定 spatial downsampling 时,16/32/64 channels 的 generation 差异不大;4/8 channels 因信息不足导致 reconstruction 明显变差。
| Latent shape | rFID@50K | gFID@50K |
|---|---|---|
| 0.40 | 2.12 | |
| 0.41 | 1.63 | |
| 1.41 | 1.74 |
Table 4 显示, 在 rFID 接近 的同时更容易被 decoder/base model 建模,gFID 更低。
| Prior | Reconstruction loss | Latent bpd with prior | Latent bpd with base model | rFID@50K | gFID@50K |
|---|---|---|---|---|---|
| Diffusion | Diffusion | 0.079 | 0.079 | 0.86 | 1.4 |
| Diffusion | MSE | 0.072 | 0.072 | 1.1 | 2.4 |
| Normal | Diffusion | 0.39 | 0.26 | 0.83 | 2.5 |
Table 5 说明 diffusion prior + diffusion reconstruction loss 是最稳的组合;MSE reconstruction 或 normal prior 虽然可能得到可接受 rFID,但 gFID 明显恶化。
| Ablation | bits/pixel | rFID@50k | PSNR | gFID@50k |
|---|---|---|---|---|
| UL baseline (LF=1.5) | 0.059 | 0.47 | 27.6 | 1.54 |
| A. prior model | 0.121 | 1.81 | nan | 7.80 |
| B. noisy latents | 0.008 | 28.27 | nan | - |
| C. ImageNet data | 0.034 | 1.37 | 24.7 | 1.63 |
| D. learned variance | 0.060 | 0.69 | nan | 1.81 |
Table 6 最能证明组件必要性:去掉 prior 对 encoder 的有效 regularization 后 gFID 退化到 7.80;去掉 noisy latents 后 bitrate estimate 失效,rFID 变成 28.27;learned variance 不如 fixed variance 稳定。
5.5 Video generation and additional samples
---How-to-train-your-latents/fig9_kinetics_compute.png)
Figure 9 解读:在 Kinetics-600 上,UL 曲线位于 MAGVIT-v2、W.A.L.T.、Video Diffusion、RIN 等 prior work 下方,表示更低 FVD 或更少 training cost。论文明确报告 UL small 已达到 1.7 FVD,UL medium 达到 1.3 FVD,并称其为当时 SOTA。
Figure 10 解读:这是非 cherry-picked text-to-image generations(guidance=2),用于补充 Figure 5 的定性结果。它展示 UL 训练出的 latent 在多个 prompt 上能保持较稳定的视觉质量,但论文没有把这些样本作为定量指标来源。
5.6 限制与结论
作者指出三个主要限制。第一,较弱/低 bitrate latent 更容易建模,但这可能把一部分建模负担转移给 decoder;这会影响如何公平比较不同 autoencoder。第二,不同方法的 AE 训练数据差异很大:Stable Diffusion AE 用大规模 web dataset,而本文多数 ImageNet 实验只用 ImageNet;DINO / SigLIP / RAE 等 semi-supervised encoder 又引入外部数据。第三,diffusion decoder 的采样成本比 GAN decoder 高一个数量级;若没有 decoder distillation,UL 的端到端推理成本会高于标准 LDM。
总体结论:UL 的价值不只是某个 FID 数字,而是提供了一套用 diffusion prior 正则化 latent、用 diffusion decoder 重建图像、再用 Stage-2 base model 生成 latent 的统一训练框架。它把 latent bitrate 从经验调参变成可解释的超参数,并在图像/视频上显示了更好的 training-efficiency trade-off。