1. Motivation (研究动机)

当前 Self-Supervised Learning (SSL) 的经验成功很强,但理论解释碎片化:SimCLR / CPC 常被解释为 Mutual Information (MI) maximization,VICReg 依赖 variance-covariance regularization,BYOL / SimSiam / JEPA 又依赖 predictor + stopgrad;这些解释很难回答同一个问题:为什么这些看似不同的目标都能避免 collapse,并学到有用 latent representation?

这篇论文要解决的具体目标是:把 contrastive、non-contrastive、predictive、stopgrad SSL 统一成同一个 Latent Distribution Matching (LDM) 问题,即让 encoder 诱导的 latent distribution 匹配一个人为指定的 latent model 。在这个视角下,alignment 来自 latent model likelihood,uniformity / anti-collapse 来自 latent entropy。

这个问题值得研究,因为一旦 SSL 被表述为可设计的 latent probability model,就不只是事后解释已有方法,而可以系统设计新目标。例如论文从 LDM 推出 Kalman-based predictive SSL,并证明在 mild assumptions 下 predictive LDM 能识别 true latent variables up to affine transformations。

2. Idea (核心思想)

核心 insight:SSL 不是本质上在最大化 MI,而是在 latent space 里做 distribution matching;MI maximization 在很多情况下只是通过有偏估计器间接实现 entropy / distribution matching。真正决定 representation geometry 的是你假设的 latent model 和 entropy estimator,而不是是否显式加 MI 项。

关键创新是把 SSL 目标写成:

它和 Wang & Isola 的 alignment-uniformity 直觉不同:后者是几何解释;本文给出 statistical distribution matching 解释,并把 SimCLR、VICReg、CPC、BYOL/SimSiam、JEPA 映射到不同的 latent distribution + entropy estimator 组合。

3. Method (方法)

3.1 Overall framework:SSL = Latent Distribution Matching

Figure 1 解读:图中数据对 经过 deterministic encoder 变成 latent pair ,形成 empirical latent distribution 。训练目标不是重建 ,而是让 匹配假设的 latent model ;其中 拉近 positive pair,对应 alignment, 防止 collapse,对应 uniformity。

从 maximum likelihood / normalizing flow 出发,若 encoder 在 data manifold 上可逆,则 latent-space likelihood 与 data likelihood 等价;单变量 ICA 的目标也可写成:

推广到 paired views 后:

论文同时定义了 MI-augmented objective:

它解释了表中 “MI max.” 的含义: 表示直接优化 表示优化额外带 。本文的关键判断是:当 entropy maximization 已经推动 encoder 在 data manifold 上近似可逆时,MI 已接近最大,因此额外 MI 项通常不是 representation quality 的主因。

直觉上, 不是“真实世界生成模型”,而是 latent geometry 的 specification:如果你假设 positive pair 服从 conditional Gaussian,就会得到 VICReg-like objective;如果你假设 sphere 上的 von Mises-Fisher conditional,就会得到 SimCLR-like objective;如果你假设 temporal Gaussian predictor,就得到 predictive SSL。

3.2 Linear ICA as LDM

Figure 2 解读:图 A 显示 linear ICA 通过 unmixing matrix 让混合数据恢复到独立 latent factors;图 B 用 natural image pixel intensity 的 non-Gaussian shape 说明匹配非高斯 source distribution 可以 disentangle;图 C 说明即使 Gaussian source 也可在额外 temporal / variance assumptions 下恢复 OU processes。这个例子用最简单的 source recovery 说明 LDM 的统计含义。

对于 ,论文写出:

因为 是常数,优化 latent likelihood + entropy / log-det 项就能恢复 factors。

3.3 Existing SSL methods as LDM instances

论文的 mapping table:

SSL MethodLatent distribution Entropy estimatorMI max.
VICRegflat + conditional Gaussianparametric Gaussian entropy
SimCLRuniform + conditional von Mises-Fisherkernel density
CPCempirical + predictive conditional vMFAitchison et al. InfoNCE view
BYOL / SimSiamempirical + conditional vMFconditional entropy plugin
JEPApredictive conditional Gaussianconditional entropy plugin

VICReg-like derivation chooses:

其 MI-form objective:

而 LDM-form objective:

SimCLR-like derivation在 unit sphere 上选择:

若用 KDE 估计 entropy, 可恢复 SimCLR / InfoNCE-like 形式:

3.4 Empirical image representation comparison

Figure 3 解读:图 A 比较 CIFAR-10 learned representations 的 eigenspectrum;低双位数附近的 cutoff 与 CIFAR-10 intrinsic dimensionality 估计一致。图 B 显示不同方法在 UMAP 中的 class geometry;重点不是某个 MI 版本显著更好,而是 latent space / entropy estimator 改变 representation geometry。

3.5 Predictive LDM and Kalman-based SSL

Figure 4 解读:图 A 是 synthetic moving-dot video;图 B 是 Kalman filter predictor,把 encoder 输出的 pseudo observation 融入 hidden state ;图 C 显示用 kNN entropy 与 stopgrad entropy 都能恢复位置;图 D 显示两类 entropy estimator 的 gradient cosine similarity 随训练上升;图 E-G 把同一模型用于 rat hippocampal spike train,并把 latent covariance 转成位置不确定性。

对于 temporal latent sequence,模型因子化为:

predictive LDM 目标:

stopgrad predictive model 的一般形式:

论文证明它在梯度意义上等价于 likelihood term 减去一个 fixed predictor 估计的 conditional entropy,因此 stopgrad 不是 magic trick,而是在估计

3.6 Identifiability of predictive LDM

Figure 5 解读:图 A 说明 Gaussian prediction error 会限制 latent reparameterization,使任意 nonlinear warp 不再可行;图 B 是 nonlinear dynamical system identification task;图 C 显示训练前 recovered latent space 混乱,训练后可线性恢复 true latent space,支持 affine-identifiability 结论。

Theorem 1 假设:

并要求 encoder 在 data manifold 上可逆、predictor covers latent space、covariance non-degenerate。则在 LDM optimum,learned representation recovers true latents up to affine transformation。

经验验证的 dynamical system:

其中 是 Gaussian noise。

3.7 Supplementary figures

Figure A Intro 解读:补充图概览了 LDM 在不同 latent-variable 设置中的直觉,帮助把 main text 的 paired-view LDM 与更一般的 probabilistic representation setting 对齐。

Figure A1 解读:图中展示多个 LDM model examples,包括 simple conditional relation、temporal relation、categorical/probabilistic encodings 等,说明 LDM 不是单一 loss,而是由 latent model specification 生成的一族目标。

Figure A2a 解读:该图补充 CIFAR-10 encoder Jacobian rank 分析;各方法都达到相近 rank,支持“MI maximization 不是 representation quality 主因”的论点。

Figure A2b 解读:该图补充不同 entropy objective 的 gradient similarity,显示 single / joint entropy 版本在优化方向上高度相关。

Figure A2c 解读:该图补充更多 representation geometry / linear probing 结果,用来支撑主文“latent space 与 entropy estimator 比 MI flag 更关键”的结论。

Figure A3 解读:该图展示 square movement task;Kalman-based SSL 可把 moving dot 的 observation 映射到可线性解码的位置状态。

Figure A4 解读:该图展示 input-dependent observation noise 的 Kalman SSL;根据输入估计 observation covariance,使模型能给出 uncertainty-aware filtering。

Figure A5 解读:该图比较 Kalman-based SSL latent state 解码位置与直接从 spikes 预测位置;latent dynamics model 的 MLP decoder ,明显优于直接 MLP 和直接 linear

Figure A6 解读:该图说明在 nonlinear identification task 中,stopgrad entropy estimator 只能 approximate identification;相比之下,显式 kNN entropy 更符合 Theorem 1 的 Gaussian predictive LDM 条件。

3.8 Pseudocode from actual implementation

import torch
import torch.nn.functional as F
 
def gaussian_ldm_loss(z1, z2, prediction_precision=0.1, entropy_type="dual_knn"):
    """Based on image-representations/solo/losses/gaussprob.py."""
    z1 = gather_across_devices(z1)
    z2 = gather_across_devices(z2)
    invariance = (-(z1 - z2).pow(2).mean(dim=1) * prediction_precision).mean()
 
    if entropy_type == "dual_knn":
        joint = torch.cat([z1, z2], dim=1)
        full_entropy = knn_entropy(joint, k=3, p=2, clip_quantile=0.9)
        loss = -invariance - full_entropy
    else:
        h1 = kde_or_logdet_entropy(z1)
        h2 = kde_or_logdet_entropy(z2)
        single_entropy = 0.5 * (h1 + h2)
        loss = -invariance - single_entropy
    return loss
def knn_entropy(x, k=3, p=2, clip_quantile=0.9, eps=1e-8):
    """Kozachenko-Leonenko-style minibatch entropy used by gaussprob/sphereprob."""
    dist = torch.cdist(x, x, p=p)
    nn_dist, _ = torch.topk(dist, k + 1, largest=False)
    nn_dist = nn_dist[:, 1:][:, -1:]
    nn_dist = torch.clamp(nn_dist, min=0.0, max=1e6)
    upper = torch.quantile(nn_dist, clip_quantile).detach()
    nn_dist = torch.clamp(nn_dist, max=upper)
    return torch.log(nn_dist + eps).mean()
def kalman_ssl_forward(observations, encoder, A, D, Sigma_A, Sigma_D):
    """Based on predictive-models/probssl/methods/kalmanSSL.py."""
    z_est = torch.zeros(observations.size(0), A.size(0), device=observations.device)
    Sigma_est = torch.eye(A.size(0), device=observations.device).repeat(observations.size(0), 1, 1)
    z_inf_seq = encoder(observations)
    preds, pred_covs, inf_covs = [], [], []
 
    for t in range(observations.size(1)):
        z_pred = torch.einsum("ij,bj->bi", A, z_est)
        z_inf_pred = torch.einsum("ij,bj->bi", D, z_pred)
        z_inf = z_inf_seq[:, t]
        error = z_inf - z_inf_pred
 
        Sigma_pred = A @ Sigma_est @ A.T + Sigma_A
        Sigma_z_pred = D @ Sigma_pred @ D.T
        Sigma_error = Sigma_z_pred + Sigma_D
        K = Sigma_pred @ D.T @ torch.inverse(Sigma_error)
        z_est = z_pred + torch.einsum("bij,bj->bi", K, error)
        Sigma_est = (torch.eye(A.size(0), device=observations.device) - K @ D) @ Sigma_pred
 
        preds.append(z_inf_pred)
        pred_covs.append(Sigma_z_pred)
        inf_covs.append(Sigma_D)
    return torch.stack(z_inf_seq.unbind(1), 1), torch.stack(preds, 1), torch.stack(inf_covs, 1), torch.stack(pred_covs, 1)
def kalman_ssl_loss(zs_inf, zs_pred, zs_inf_cov, zs_pred_cov):
    """Based on predictive-models/probssl/losses/ssl/kalmanSSL.py."""
    loss = 0.0
    for t in range(zs_inf.size(1)):
        Sigma_e = zs_pred_cov[:, t] + zs_inf_cov[:, t]
        target = zs_inf[:, t].detach()
        mean = zs_pred[:, t]
        alignment = -gaussian_cross_entropy(mean, target, Sigma_e).mean()
        loss = loss - alignment
    return loss

Code reference: main @ 01d2ae83 (2026-05-06) — pseudocode and mapping based on this commit

Paper ConceptSource FileKey Class/Function
Gaussian LDM image representationimage-representations/solo/losses/gaussprob.pygaussprob_loss_func_*, gauss_similarity, kozachenko_leonenko_*
Spherical / vMF LDM image representationimage-representations/solo/losses/sphereprob.pysphereprob_loss_func_*, sphere_similarity, log_det_expansion
Training wrapper for Gaussian LDMimage-representations/solo/methods/gaussprob.pyGaussProb.training_step
Training wrapper for spherical LDMimage-representations/solo/methods/sphereprob.pySphereProb.training_step
Kalman predictive SSL encoder/predictorpredictive-models/probssl/methods/kalmanSSL.pyKalmanSSLEncoder.forward, KalmanSSL.training_step
Kalman alignment losspredictive-models/probssl/losses/ssl/kalmanSSL.pykalmanSSL_loss_func
Synthetic moving-dot datasetspredictive-models/probssl/data/dot_motion.pyDotMotion, trajectory generators

4. Experimental Setup (实验设置)

  • Datasets used and scale:
    • CIFAR-10、CIFAR-100、SVHN、Imagenet-100:用于 SSL pretraining + linear probing;论文未在正文列出样本数,采用 standard solo-learn setup。
    • Synthetic moving-dot videos:800 training sequences,10×10 pixel resolution,20 epochs。
    • Rat hippocampal hc-11 dataset:120 neurons,25 ms spike-count bins,约 45 min spike train;每 epoch 采样 100000 windows,每个 window 约 10 s,训练 16 epochs。
    • Nonlinear system identification:10000 sequences,300 epochs;高维补充实验使用 100000 samples,20 epochs。
  • Baselines / variants:
    • LDM variants: Plane/Sphere latent space × Contrastive/KNN/LogDet entropy × or
    • Named equivalents: VICReg = Plane-LogDet-;SimCLR = Sphere-Contr.-
    • Predictive variants: Kalman SSL with kNN entropy, stopgrad entropy, direct spike-to-position prediction, nonlinear predictor identification variants。
  • Evaluation metrics:
    • Linear probing Top-1 / Top-5 validation accuracy:冻结 representation 后训练 linear classifier。
    • :linear / MLP decoder 对 true position 或 latent state 的解释方差。
    • Gradient cosine similarity:比较不同 entropy objective 对网络权重梯度方向是否一致。
    • Jacobian rank / eigenspectrum / UMAP:分析 representation dimensionality 与 class geometry。
  • Training config:
    • Image representation: solo-learn extended implementation,ResNet-18 encoder;CIFAR 1000 epochs,Imagenet-100 400 epochs;三次 runs 计算 standard deviation。
    • kNN entropy: Euclidean metric ,丢弃 upper 10% kNN outliers。
    • Kalman synthetic: MLP encoder,1 hidden layer,100 units;diagonal ;Kalman filter faster timescale than encoder。
    • hc-11: 8D observation,16D latent state;position uncertainty 通过从 采样 1000 次估计 95% CI。
    • Nonlinear identification: encoder MLP 10 hidden layers × 200 units;predictor = 1-layer LSTM with 10 hidden units + 2-layer MLP head with 20 hidden units。

5. Experimental Results (实验结果)

5.1 Main linear probing results

SpaceEntropyMICIFAR-10 Top-1CIFAR-100 Top-1Imagenet-100 Top-1
PlaneContr.91.9±0.165.3±0.372.6
PlaneContr.92.1±0.165.3±0.272.7
PlanekNN92.1±0.265.6±0.474.3
PlanekNN91.9±0.165.8±0.473.7
PlaneLogDet92.1±0.169.5±0.275.9
PlaneLogDet91.9±0.168.6±0.174.7
SphereContr.90.9±0.164.8±0.272.1
SphereContr.91.4±0.266.0±0.273.1
SpherekNN90.2±0.064.3±0.172.6
SpherekNN90.0±0.264.5±0.273.3
SphereLogDet91.4±0.265.4±0.273.0
SphereLogDet91.2±0.165.4±0.272.3

主要结论:with / without MI maximization 没有系统性差异;影响更大的是 latent space 与 entropy estimator。Plane-LogDet- 在 CIFAR-100 和 Imagenet-100 上最好,说明直接 LDM entropy matching 不必依赖 MI term。

5.2 Predictive SSL results

  • Synthetic Kalman task:CPC-like MI maximization 与 predictive stopgrad SSL 都恢复 true latent variables,linear probing
  • hc-11 hippocampal spikes:Kalman latent state 排列成 position/direction circle,并能输出位置不确定性;补充图显示 latent MLP decoder ,latent linear decoder ,直接 spike MLP ,直接 spike linear
  • Nonlinear system identification:predictive LDM 训练后 recovered latent space 可线性恢复 true position,
  • Square movement 补充实验:contrastive 和 stopgrad predictive SSL 均找到好 encoding,;input-dependent observation noise 补充实验中 hidden variables 轨迹估计

5.3 Ablations / key findings

  • MI maximization 的边际贡献小:论文从理论上指出 MI 对 invertible transforms 不敏感;实验上 的 linear probing、eigenspectrum、Jacobian rank、gradient direction 都相近。
  • Entropy estimator 更关键:KDE 对应 contrastive,LogDet 对应 non-contrastive covariance regularization,kNN 给出 sampling-free / contrastive-adjacent 估计;这些选择明显改变性能和 geometry。
  • stopgrad 的角色被重新解释:它近似条件熵估计,而不是单纯 heuristic collapse avoidance。
  • Identifiability 来自 predictive residual distribution:Gaussian prediction error 排除了任意 nonlinear reparameterization,使 representation 只能差一个 affine transform。

5.4 Limitations

作者明确提到:BYOL/SimSiam derivation 使用 empirical prior,future work 可以探索 proper prior,以减少额外 model regularization 的需求。另一个限制是 identifiability 结论依赖 encoder 在 data manifold 上可逆、predictor coverage、noise / predictive distribution assumptions;在 nonlinear scenario 中,stopgrad estimator 只得到 approximate identification。论文的实验主要是 controlled image / synthetic dynamics / spike-train setup;我推断这意味着其大规模 foundation-model 泛化仍需额外验证。

总体结论:LDM 给 SSL 提供了比 MI maximization 更直接的统一解释;它能解释已有方法、设计新 predictive SSL、并给出 affine-identifiability 保证。对实践者来说,关键设计旋钮是 latent model 与 entropy estimator,而不是盲目追求 MI objective。