DMD2: Improved Distribution Matching Distillation for Fast Image Synthesis

Authors: Tianwei Yin, Michael Gharbi, Taesung Park, Richard Zhang, Eli Shechtman, Fredo Durand, William T. Freeman Affiliations: MIT, Adobe Research arXiv: 2405.14867 Project Page: tianweiy.github.io/dmd2 GitHub: tianweiy/DMD2 Venue: NeurIPS 2024 (Oral)

1. Motivation (研究动机)

1.1 问题背景

Diffusion model 生成质量出色,但采样需要数十到数百步迭代去噪,推理成本极高。蒸馏 (distillation) 是主流加速方案,目标是将多步 teacher diffusion model 压缩为 one-step 或 few-step 的 student generator。

现有蒸馏方法大致分两类:

  1. 轨迹匹配 (trajectory matching):如 Progressive Distillation、Consistency Distillation,让 student 学习 teacher 的 ODE 采样轨迹。需要 noise-image pair 的 regression loss,学生质量被 teacher 上限约束。
  2. 分布匹配 (distribution matching):如 DMD (Distribution Matching Distillation),直接对齐 student 和 teacher 的输出分布(KL 散度),不强制 one-to-one 轨迹对应。

1.2 DMD 的关键局限

DMD [Yin et al., 2024] 虽然取得了 SOTA 的 one-step 生成结果,但存在三个核心问题:

  1. 依赖 regression loss + 昂贵数据集:仅用 distribution matching loss 训练不稳定,需要额外的 regression loss 来稳定训练。这要求预先用 teacher 的 deterministic sampler 生成数百万 noise-image pairs(对 SDXL 需要约 700 A100-days),成本极高,且将 student 质量锁定在 teacher 的采样路径上。

  2. 仅支持 one-step 生成:无法扩展到 multi-step,对 SDXL 等大模型,one-step 映射过于复杂,质量受限。

  3. 训练-推理不匹配:multi-step 方法在训练时用真实图像加噪作为中间步输入,推理时中间步来自 student 自身生成,存在 domain gap。

1.3 核心矛盾

DMD 的 distribution matching 理念(不约束具体采样路径,只匹配分布)与 regression loss 的 trajectory-binding 本质上矛盾——regression loss 强迫 student 复现 teacher 的特定采样路径,限制了 student 超越 teacher 的可能性。

Figure 3 解读:DMD2 整体训练流程。左侧:few-step generator 从噪声生成图像(红色路径),同时用真实图像(绿色路径)做 GAN 训练。右侧:两个 score function——frozen real score (预训练 teacher)和 trainable fake score (在线训练的 critic),两者的差作为 distribution matching gradient(红色箭头)。下方:GAN discriminator 在 fake UNet 的 bottleneck 上加 classification head,区分真假图像(绿色箭头)。训练交替进行两步:(1) 用 DM gradient + GAN loss 更新 generator;(2) 用 denoising loss + GAN classification loss 更新 fake score function 和 discriminator。关键改进:fake score 的更新频率是 generator 的 5 倍(two time-scale update rule)。


2. Idea (核心思想)

2.1 核心洞察一:训练不稳定的根源是 fake critic 跟不上 generator

论文发现,去掉 regression loss 后训练不稳定(亮度/色彩剧烈波动、FID 不收敛),根本原因是 fake diffusion model 无法准确估计快速变化的 generator 输出分布。由于 generator 每步都在更新,fake score function 需要实时跟踪一个 non-stationary 分布,用相同的更新频率会导致 score 估计不准,进而产生有偏梯度。

解决方案:Two Time-scale Update Rule (TTUR)——让 fake score 的更新频率高于 generator(5:1),确保 critic 始终准确跟踪 generator 分布。

2.2 核心洞察二:GAN loss 弥补 teacher score 的近似误差

即使有了 TTUR,student 质量仍被 real score function(即 teacher diffusion model)的近似误差限制。Teacher 并不是数据分布的完美 score estimator,其误差会传导到 student。

解决方案:引入 GAN loss,discriminator 直接在真实图像和生成图像之间判别。GAN loss 不依赖 teacher,直接从真实数据获取监督信号,可以弥补 teacher 的 score 近似误差,甚至使 student 超越 teacher。

2.3 核心洞察三:Backward Simulation 消除训练-推理不匹配

Multi-step 蒸馏方法普遍存在 train/test domain gap:训练时中间步输入是真实图像加噪,推理时是 student 自身生成的中间结果。

解决方案:Backward Simulation——训练时不用真实图像加噪,而是让 student generator 从纯噪声开始实际运行几步,用其输出作为中间步输入。这样训练和推理看到的分布完全一致。

Figure 4 解读:左侧(传统方法):训练时用真实图像做 forward diffusion 得到中间步噪声输入,但推理时中间步来自 student 的 backward process,两者存在 domain gap(红色虚线框)。右侧(DMD2 的 Backward Simulation):训练时直接模拟推理过程——student 从噪声出发做 backward sampling 生成 “fake” 中间步样本,然后在这些样本上继续去噪并计算 loss。这样训练和推理的输入分布完全对齐(绿色框)。


3. Method (方法)

3.1 Background:DMD 的 Distribution Matching Loss

DMD 的核心是最小化 diffused 分布之间的 KL 散度。对于 generator ,其梯度可以表示为两个 score function 的差:

其中 是预训练 teacher(frozen), 是在 generator 输出上动态训练的 fake score function, 是 forward diffusion(加噪)。

在实现中,distribution matching gradient 被计算为:

其中 , 分别是 real/fake UNet 从加噪图像预测的干净图像。

3.2 改进一:去除 Regression Loss + TTUR 稳定训练

原始 DMD 需要:

DMD2 直接去掉 ,但用 TTUR 稳定训练:每 1 次 generator 更新对应 5 次 fake score 更新。

# 伪代码:Two Time-scale Update Rule
for iteration in range(total_iterations):
    # ---- Fake Score / Guidance 更新 (5 次) ----
    for _ in range(5):  # fake_score_updates_per_generator_step
        z = sample_noise()
        fake_images = generator(z).detach()  # 不回传梯度到 generator
 
        # 1. Denoising score matching loss: 让 fake_unet 学会对 fake_images 去噪
        t = sample_timestep()
        noisy_fake = add_noise(fake_images, t)
        pred_noise = fake_unet(noisy_fake, t)
        loss_denoise = MSE(pred_noise, added_noise)
 
        # 2. GAN discriminator loss: 区分真假
        logits_real = discriminator(real_images)
        logits_fake = discriminator(fake_images)
        loss_disc = softplus(logits_fake).mean() + softplus(-logits_real).mean()
 
        update(fake_unet, discriminator, loss_denoise + loss_disc)
 
    # ---- Generator 更新 (1 次) ----
    z = sample_noise()
    fake_images = generator(z)
 
    # Distribution Matching gradient
    t = sample_timestep()
    noisy = add_noise(fake_images, t)
    pred_real = real_unet(noisy, t).detach()   # frozen teacher
    pred_fake = fake_unet(noisy, t).detach()   # 只用来提供梯度方向
    grad = normalize(pred_fake - pred_real)
    loss_dm = 0.5 * MSE(fake_images, (fake_images - grad).detach())
 
    # GAN generator loss
    logits_fake = discriminator(fake_images)
    loss_gen = softplus(-logits_fake).mean()
 
    update(generator, loss_dm + loss_gen)

3.3 改进二:GAN Loss 集成

GAN objective 采用标准 non-saturating loss,discriminator 加在 fake UNet 的 bottleneck 特征上:

Discriminator 架构极其简洁——在 fake UNet bottleneck(1280 维)上加一个小型 classification head:

# GAN Discriminator: 在 UNet bottleneck 上的分类头
self.cls_pred_branch = nn.Sequential(
    nn.Conv2d(1280, 1280, kernel_size=4, stride=2, padding=1),
    nn.GroupNorm(32, 1280),
    nn.SiLU(),
    nn.Conv2d(1280, 1, kernel_size=1, stride=1, padding=0),
)

设计亮点:

  • 复用 fake UNet 的特征提取能力,不需要独立 discriminator
  • 对加噪图像判别(而非干净图像),天然利用 diffusion 的 noise injection 稳定 GAN 训练
  • GAN loss 独立于 teacher,不受 teacher score 近似误差影响

3.4 改进三:Multi-step Generator + Backward Simulation

Multi-step 推理:固定 个 timestep ,从纯噪声 开始,交替做去噪和加噪:

4-step model 使用 schedule: (对应 1000 步 teacher)。

Backward Simulation 训练

# 伪代码:Backward Simulation for Multi-step Training
def backward_simulation_training(generator, real_unet, fake_unet, discriminator):
    z = sample_noise()  # 纯噪声起点
 
    # 从 denoising_step_list 中随机选择一个起始步
    selected_step = random.choice(range(len(denoising_step_list)))
 
    # 模拟推理:student 自身做几步 backward sampling(no grad)
    noisy_image = z
    with torch.no_grad():
        for i in range(selected_step):
            t = denoising_step_list[i]
            pred_noise = generator(noisy_image, t)
            pred_clean = get_x0_from_noise(noisy_image, pred_noise, t)
            # 加噪到下一步
            next_t = denoising_step_list[i + 1]
            noisy_image = add_noise(pred_clean, next_t)
 
    # 在模拟得到的中间步上训练(有梯度)
    t_current = denoising_step_list[selected_step]
    pred_noise = generator(noisy_image, t_current)
    pred_clean = get_x0_from_noise(noisy_image, pred_noise, t_current)
 
    # 对 pred_clean 计算 DM loss + GAN loss(同 3.2 节)
    loss = compute_dm_loss(pred_clean, real_unet, fake_unet, t_current)
    loss += compute_gan_loss(pred_clean, discriminator, t_current)
    loss.backward()

3.5 完整训练流程总结

DMD2 的完整 loss 为:

训练交替执行:

  1. Guidance turn (5x):更新 fake UNet + discriminator head
  2. Generator turn (1x):更新 student generator

代码对应 (main/train_sd.py train_one_step):

def train_one_step(self):
    # Generator turn
    noise = torch.randn(batch_size, 4, latent_h, latent_w)
    loss_dict, log_dict = self.model(
        noise, text_embedding, real_image,
        generator_turn=True, guidance_turn=False
    )
    loss = loss_dict["loss_dm"] + args.cls_loss_weight * loss_dict["gen_cls_loss"]
    loss.backward()
    generator_optimizer.step()
 
    # Guidance turn (called more frequently via outer loop)
    loss_dict, log_dict = self.model(
        noise, text_embedding, real_image,
        generator_turn=False, guidance_turn=True
    )
    loss = loss_dict["loss_fake_mean"] + args.cls_loss_weight * loss_dict["loss_cls"]
    loss.backward()
    guidance_optimizer.step()

代码对应 (main/sd_guidance.py DM loss 计算):

# Distribution Matching Loss 核心计算
pred_real_noise = predict_noise(self.real_unet, noisy_latents, timesteps, ...)
pred_real_image = get_x0_from_noise(noisy_latents, pred_real_noise, alphas_cumprod, timesteps)
 
pred_fake_noise = predict_noise(self.fake_unet, noisy_latents, timesteps, ...)
pred_fake_image = get_x0_from_noise(noisy_latents, pred_fake_noise, alphas_cumprod, timesteps)
 
p_real = latents - pred_real_image   # real score direction
p_fake = latents - pred_fake_image   # fake score direction
 
# Normalized gradient
grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1,2,3], keepdim=True)
loss_dm = 0.5 * F.mse_loss(original_latents, (original_latents - grad).detach())

代码对应 (main/sd_guidance.py GAN discriminator):

# GAN 判别
rep = self.fake_unet.forward(noisy_latents, timesteps, ..., classify_mode=True)
logits = self.cls_pred_branch(rep).squeeze(dim=[2, 3])
 
# Discriminator loss
loss_cls = F.softplus(pred_realism_on_fake).mean() + F.softplus(-pred_realism_on_real).mean()
 
# Generator loss
loss_gen = F.softplus(-pred_realism_on_fake_with_grad).mean()

代码对应 (main/sd_unified_model.py Backward Simulation):

def sample_backward(self, noisy_image, text_embedding, pooled_text_embedding):
    for constant in self.denoising_step_list[:selected_step]:
        generated_noise = self.feedforward_model(noisy_image, constant, ...)
        generated_image = get_x0_from_noise(noisy_image, generated_noise, ...)
        noisy_image = self.noise_scheduler.add_noise(generated_image, noise, next_step)
    return noisy_image

4. Experimental Setup (实验设置)

4.1 Class-conditional Generation (ImageNet-64x64)

  • Teacher: EDM (Karras et al., 2022),511 步 ODE sampler
  • Student: 与 teacher 同架构的 UNet,one-step generator
  • 训练数据: ImageNet-64x64 训练集
  • 评估: FID(50K 生成 vs 全部训练集)

4.2 Text-to-Image Synthesis (COCO 2014)

  • Teacher: SDXL / Stable Diffusion v1.5
  • Student: 与 teacher 同架构,支持 1-step 和 4-step
  • 训练 prompts: 3M subset from LAION-Aesthetics
  • GAN 真实图像: 500K from LAION-Aesthetic
  • 评估: FID、Patch FID(299x299 center crop)、CLIP Score(10K zero-shot COCO prompts)
  • 4-step schedule:

4.3 Baselines

  • Trajectory-based: Progressive Distillation, Consistency Model, iCT-deep, CTM
  • Distribution-based: DMD (original)
  • GAN-based: SDXL-Turbo (ADD), SDXL-Lightning, LCM-SDXL
  • Hybrid: TRACT, BOOT, Diff-Instruct

4.4 Human Evaluation

  • 128 prompts from PartiPrompts
  • 5 位评估者对比 Image Quality 和 Prompt Alignment
  • 与 SDXL-LCM、SDXL-Turbo、SDXL-Lightning、SDXL Teacher 对比

5. Experimental Results (实验结果)

5.1 ImageNet-64x64 Class-conditional Generation

Method# Fwd PassFID
BigGAN-deep14.06
StyleGAN-XL11.52
Consistency Model16.20
iCT-deep13.25
CTM11.92
DMD (original)12.62
DMD2 (Ours)11.51
DMD2 +longer training11.28
EDM Teacher (ODE)5112.32
EDM Teacher (SDE)5111.36

关键发现

  • DMD2 one-step FID 1.28,超越了 511 步 teacher(ODE: 2.32),是首个 one-step generator 在 ImageNet-64x64 上超越 teacher 的方法
  • 比 original DMD 提升 1.34 FID points(2.62 1.28)
  • 比同为 one-step 的 CTM 提升 0.64 points

5.2 SDXL Text-to-Image (COCO 2014)

Method# StepsFIDPatch FIDCLIP
LCM-SDXL1 / 481.62 / 22.16154.40 / 33.920.275 / 0.317
SDXL-Turbo1 / 424.57 / 23.1923.94 / 23.270.337 / 0.334
SDXL-Lightning1 / 423.92 / 24.4631.65 / 24.560.316 / 0.323
DMD2 (Ours)1 / 419.01 / 19.3226.98 / 20.860.336 / 0.332
SDXL Teacher (cfg=6)10019.3621.380.332
SDXL Teacher (cfg=8)10020.3923.210.335

关键发现

  • DMD2 4-step FID 19.32,接近 100 步 teacher(19.36),仅用 4% 的计算量
  • 1-step FID 19.01 甚至优于 teacher,可能因为 GAN loss 修正了 teacher 的 score 误差
  • Patch FID 20.86 超越所有竞争方法和 teacher,表明生成的高频细节质量更高

Figure 5 解读:用户偏好评估(128 prompts x 5 评估者)。左图 Image Quality:DMD2 在所有对比中都获得更高偏好率——vs SDXL-LCM(76.6%)、vs SDXL-Turbo(60.5%)、vs SDXL-Lightning(53.4%)、甚至 vs 50 步 SDXL Teacher(62.0%)。右图 Prompt Alignment:DMD2 同样全面领先,但优势略小——vs Teacher 为 50.5% 略高于 49.5%。所有蒸馏模型均使用 4 步,teacher 使用 50 步。这证明 DMD2 在 25 倍加速下不仅不降质,还能超越 teacher。

5.3 Ablation Studies

ImageNet Ablation (Table 3):

DMDNo Regress.TTURGANFID
yes2.62
yesyes3.48
yesyesyes2.61
yesyesyesyes1.51
yes2.56
yesyes2.52
  • 直接去掉 regression loss:FID 从 2.62 恶化到 3.48(不稳定)
  • 加入 TTUR:恢复到 2.61,匹配原始 DMD(无需数据集构建)
  • 加入 GAN:大幅提升到 1.51,比单独 GAN(2.56)或 DM+TTUR(2.61)都好
  • DM + GAN 的组合效果远超各自单独使用

SDXL Ablation (Table 4):

MethodFIDPatch FIDCLIP
w/o GAN26.9027.660.328
w/o Distribution Matching13.7727.960.307
w/o Backward Simulation20.6624.210.332
DMD2 (Full)19.3220.860.332
  • 去掉 GAN:质量大幅下降(FID 19.32 26.90),图像过度饱和/模糊
  • 去掉 DM:FID 看似不错(13.77)但 CLIP score 极低(0.307),说明 mode collapse / 缺乏多样性
  • 去掉 Backward Simulation:Patch FID 恶化(20.86 24.21),高频细节退化

Figure 7 解读:SDXL 定性 ablation。四列分别是 DMD2 (Full)、w/o Distribution Matching、w/o GAN、w/o Backward Simulation。五行不同 prompt。去掉 DM(第二列):图像风格/构图偏差大,文本对齐差。去掉 GAN(第三列):图像过度饱和、过度平滑、缺乏真实感。去掉 Backward Simulation(第四列):细节质量下降但整体可接受。完整 DMD2(第一列)在真实感、细节和文本对齐上都最佳。

Figure 6 解读:与竞争方法的视觉对比。四行不同 prompt,五列分别是 DMD2、LCM、Turbo、Lightning、Teacher。所有蒸馏模型 4 步,teacher 50 步。DMD2 在真实感和 prompt alignment 上显著优于其他蒸馏方法:如 “llama wearing sunglasses on a spaceship” DMD2 生成的太空背景更逼真;“shiba inu wearing a beret” DMD2 的毛发纹理更自然;“young girl playing piano” DMD2 的构图和人物表情更生动。

Figure 1 解读:DMD2 4-step generator 蒸馏自 SDXL 生成的 1024x1024 高质量样本展示。包含人像(小女孩、老人)、动物(骑马、金刚、树懒)、场景(纽约时代广场泰迪熊、帝国大厦)、食物(西瓜)等多种类别,展示了方法在多样化 prompt 下生成高分辨率、高质量图像的能力。

Figure (ImageNet Visual) 解读:ImageNet-64x64 上的 one-step 生成样本。包含不同类别(火烈鸟、金鱼、蘑菇、蝴蝶、教堂、狗等),展示了方法在 class-conditional 生成中的高质量和多样性。FID 1.28 的定量结果在这些样本中得到了直观验证。

5.4 SD v1.5 结果

DMD2 同样适用于较小模型:在 SD v1.5 上实现 one-step FID 8.35(COCO 2014 zero-shot),比原始 DMD 的 11.49 提升 3.14 points,也超越了使用 50 步 PNDM sampler 的 teacher model。


总结

DMD2 通过三个关键改进将 Distribution Matching Distillation 推向新高度:

  1. 去除 regression loss + TTUR:消除了昂贵的数据集构建需求(节省约 700 A100-days for SDXL),同时通过 two time-scale update rule 保持训练稳定性。核心洞察是 fake critic 需要比 generator 更快地更新以准确跟踪非平稳分布。

  2. GAN loss 集成:在 fake UNet bottleneck 上加 classification head,直接从真实数据获取监督信号,突破了 teacher score function 近似误差的限制,使 student 首次能够超越 teacher。

  3. Backward Simulation:训练时模拟推理时的 backward process,消除 multi-step 蒸馏中的 train/test domain gap,提升高频细节质量。

最终实现:ImageNet-64x64 one-step FID 1.28(超越 511 步 teacher),SDXL 4-step FID 19.32(匹敌 100 步 teacher),推理加速 500 倍

Limitations:(1) 图像多样性略低于 teacher(diversity-quality tradeoff);(2) SDXL 仍需 4 步才能匹配 teacher 质量;(3) 使用固定 guidance scale,缺乏灵活性;(4) 训练仍需大量计算资源(但已比 DMD 大幅降低)。