1. Motivation (研究动机)
当前 Self-Supervised Learning (SSL) 的经验成功很强,但理论解释碎片化:SimCLR / CPC 常被解释为 Mutual Information (MI) maximization,VICReg 依赖 variance-covariance regularization,BYOL / SimSiam / JEPA 又依赖 predictor + stopgrad;这些解释很难回答同一个问题:为什么这些看似不同的目标都能避免 collapse,并学到有用 latent representation?
这篇论文要解决的具体目标是:把 contrastive、non-contrastive、predictive、stopgrad SSL 统一成同一个 Latent Distribution Matching (LDM) 问题,即让 encoder 诱导的 latent distribution 匹配一个人为指定的 latent model 。在这个视角下,alignment 来自 latent model likelihood,uniformity / anti-collapse 来自 latent entropy。
这个问题值得研究,因为一旦 SSL 被表述为可设计的 latent probability model,就不只是事后解释已有方法,而可以系统设计新目标。例如论文从 LDM 推出 Kalman-based predictive SSL,并证明在 mild assumptions 下 predictive LDM 能识别 true latent variables up to affine transformations。
2. Idea (核心思想)
核心 insight:SSL 不是本质上在最大化 MI,而是在 latent space 里做 distribution matching;MI maximization 在很多情况下只是通过有偏估计器间接实现 entropy / distribution matching。真正决定 representation geometry 的是你假设的 latent model 和 entropy estimator,而不是是否显式加 MI 项。
关键创新是把 SSL 目标写成:
它和 Wang & Isola 的 alignment-uniformity 直觉不同:后者是几何解释;本文给出 statistical distribution matching 解释,并把 SimCLR、VICReg、CPC、BYOL/SimSiam、JEPA 映射到不同的 latent distribution + entropy estimator 组合。
3. Method (方法)
3.1 Overall framework:SSL = Latent Distribution Matching
Figure 1 解读:图中数据对 经过 deterministic encoder 变成 latent pair ,形成 empirical latent distribution 。训练目标不是重建 ,而是让 匹配假设的 latent model ;其中 拉近 positive pair,对应 alignment, 防止 collapse,对应 uniformity。
从 maximum likelihood / normalizing flow 出发,若 encoder 在 data manifold 上可逆,则 latent-space likelihood 与 data likelihood 等价;单变量 ICA 的目标也可写成:
推广到 paired views 后:
论文同时定义了 MI-augmented objective:
它解释了表中 “MI max.” 的含义: 表示直接优化 , 表示优化额外带 的 。本文的关键判断是:当 entropy maximization 已经推动 encoder 在 data manifold 上近似可逆时,MI 已接近最大,因此额外 MI 项通常不是 representation quality 的主因。
直觉上, 不是“真实世界生成模型”,而是 latent geometry 的 specification:如果你假设 positive pair 服从 conditional Gaussian,就会得到 VICReg-like objective;如果你假设 sphere 上的 von Mises-Fisher conditional,就会得到 SimCLR-like objective;如果你假设 temporal Gaussian predictor,就得到 predictive SSL。
3.2 Linear ICA as LDM
Figure 2 解读:图 A 显示 linear ICA 通过 unmixing matrix 让混合数据恢复到独立 latent factors;图 B 用 natural image pixel intensity 的 non-Gaussian shape 说明匹配非高斯 source distribution 可以 disentangle;图 C 说明即使 Gaussian source 也可在额外 temporal / variance assumptions 下恢复 OU processes。这个例子用最简单的 source recovery 说明 LDM 的统计含义。
对于 ,论文写出:
因为 是常数,优化 latent likelihood + entropy / log-det 项就能恢复 factors。
3.3 Existing SSL methods as LDM instances
论文的 mapping table:
| SSL Method | Latent distribution | Entropy estimator | MI max. |
|---|---|---|---|
| VICReg | flat + conditional Gaussian | parametric Gaussian entropy | |
| SimCLR | uniform + conditional von Mises-Fisher | kernel density | |
| CPC | empirical + predictive conditional vMF | Aitchison et al. InfoNCE view | |
| BYOL / SimSiam | empirical + conditional vMF | conditional entropy plugin | |
| JEPA | predictive conditional Gaussian | conditional entropy plugin |
VICReg-like derivation chooses:
其 MI-form objective:
而 LDM-form objective:
SimCLR-like derivation在 unit sphere 上选择:
若用 KDE 估计 entropy, 可恢复 SimCLR / InfoNCE-like 形式:
3.4 Empirical image representation comparison
Figure 3 解读:图 A 比较 CIFAR-10 learned representations 的 eigenspectrum;低双位数附近的 cutoff 与 CIFAR-10 intrinsic dimensionality 估计一致。图 B 显示不同方法在 UMAP 中的 class geometry;重点不是某个 MI 版本显著更好,而是 latent space / entropy estimator 改变 representation geometry。
3.5 Predictive LDM and Kalman-based SSL
Figure 4 解读:图 A 是 synthetic moving-dot video;图 B 是 Kalman filter predictor,把 encoder 输出的 pseudo observation 融入 hidden state ;图 C 显示用 kNN entropy 与 stopgrad entropy 都能恢复位置;图 D 显示两类 entropy estimator 的 gradient cosine similarity 随训练上升;图 E-G 把同一模型用于 rat hippocampal spike train,并把 latent covariance 转成位置不确定性。
对于 temporal latent sequence,模型因子化为:
predictive LDM 目标:
stopgrad predictive model 的一般形式:
论文证明它在梯度意义上等价于 likelihood term 减去一个 fixed predictor 估计的 conditional entropy,因此 stopgrad 不是 magic trick,而是在估计 。
3.6 Identifiability of predictive LDM
Figure 5 解读:图 A 说明 Gaussian prediction error 会限制 latent reparameterization,使任意 nonlinear warp 不再可行;图 B 是 nonlinear dynamical system identification task;图 C 显示训练前 recovered latent space 混乱,训练后可线性恢复 true latent space,支持 affine-identifiability 结论。
Theorem 1 假设:
并要求 encoder 在 data manifold 上可逆、predictor covers latent space、covariance non-degenerate。则在 LDM optimum,learned representation recovers true latents up to affine transformation。
经验验证的 dynamical system:
其中 ,,, 是 Gaussian noise。
3.7 Supplementary figures
Figure A Intro 解读:补充图概览了 LDM 在不同 latent-variable 设置中的直觉,帮助把 main text 的 paired-view LDM 与更一般的 probabilistic representation setting 对齐。
Figure A1 解读:图中展示多个 LDM model examples,包括 simple conditional relation、temporal relation、categorical/probabilistic encodings 等,说明 LDM 不是单一 loss,而是由 latent model specification 生成的一族目标。
Figure A2a 解读:该图补充 CIFAR-10 encoder Jacobian rank 分析;各方法都达到相近 rank,支持“MI maximization 不是 representation quality 主因”的论点。
Figure A2b 解读:该图补充不同 entropy objective 的 gradient similarity,显示 single / joint entropy 版本在优化方向上高度相关。
Figure A2c 解读:该图补充更多 representation geometry / linear probing 结果,用来支撑主文“latent space 与 entropy estimator 比 MI flag 更关键”的结论。
Figure A3 解读:该图展示 square movement task;Kalman-based SSL 可把 moving dot 的 observation 映射到可线性解码的位置状态。
Figure A4 解读:该图展示 input-dependent observation noise 的 Kalman SSL;根据输入估计 observation covariance,使模型能给出 uncertainty-aware filtering。
Figure A5 解读:该图比较 Kalman-based SSL latent state 解码位置与直接从 spikes 预测位置;latent dynamics model 的 MLP decoder ,明显优于直接 MLP 和直接 linear 。
Figure A6 解读:该图说明在 nonlinear identification task 中,stopgrad entropy estimator 只能 approximate identification;相比之下,显式 kNN entropy 更符合 Theorem 1 的 Gaussian predictive LDM 条件。
3.8 Pseudocode from actual implementation
import torch
import torch.nn.functional as F
def gaussian_ldm_loss(z1, z2, prediction_precision=0.1, entropy_type="dual_knn"):
"""Based on image-representations/solo/losses/gaussprob.py."""
z1 = gather_across_devices(z1)
z2 = gather_across_devices(z2)
invariance = (-(z1 - z2).pow(2).mean(dim=1) * prediction_precision).mean()
if entropy_type == "dual_knn":
joint = torch.cat([z1, z2], dim=1)
full_entropy = knn_entropy(joint, k=3, p=2, clip_quantile=0.9)
loss = -invariance - full_entropy
else:
h1 = kde_or_logdet_entropy(z1)
h2 = kde_or_logdet_entropy(z2)
single_entropy = 0.5 * (h1 + h2)
loss = -invariance - single_entropy
return lossdef knn_entropy(x, k=3, p=2, clip_quantile=0.9, eps=1e-8):
"""Kozachenko-Leonenko-style minibatch entropy used by gaussprob/sphereprob."""
dist = torch.cdist(x, x, p=p)
nn_dist, _ = torch.topk(dist, k + 1, largest=False)
nn_dist = nn_dist[:, 1:][:, -1:]
nn_dist = torch.clamp(nn_dist, min=0.0, max=1e6)
upper = torch.quantile(nn_dist, clip_quantile).detach()
nn_dist = torch.clamp(nn_dist, max=upper)
return torch.log(nn_dist + eps).mean()def kalman_ssl_forward(observations, encoder, A, D, Sigma_A, Sigma_D):
"""Based on predictive-models/probssl/methods/kalmanSSL.py."""
z_est = torch.zeros(observations.size(0), A.size(0), device=observations.device)
Sigma_est = torch.eye(A.size(0), device=observations.device).repeat(observations.size(0), 1, 1)
z_inf_seq = encoder(observations)
preds, pred_covs, inf_covs = [], [], []
for t in range(observations.size(1)):
z_pred = torch.einsum("ij,bj->bi", A, z_est)
z_inf_pred = torch.einsum("ij,bj->bi", D, z_pred)
z_inf = z_inf_seq[:, t]
error = z_inf - z_inf_pred
Sigma_pred = A @ Sigma_est @ A.T + Sigma_A
Sigma_z_pred = D @ Sigma_pred @ D.T
Sigma_error = Sigma_z_pred + Sigma_D
K = Sigma_pred @ D.T @ torch.inverse(Sigma_error)
z_est = z_pred + torch.einsum("bij,bj->bi", K, error)
Sigma_est = (torch.eye(A.size(0), device=observations.device) - K @ D) @ Sigma_pred
preds.append(z_inf_pred)
pred_covs.append(Sigma_z_pred)
inf_covs.append(Sigma_D)
return torch.stack(z_inf_seq.unbind(1), 1), torch.stack(preds, 1), torch.stack(inf_covs, 1), torch.stack(pred_covs, 1)def kalman_ssl_loss(zs_inf, zs_pred, zs_inf_cov, zs_pred_cov):
"""Based on predictive-models/probssl/losses/ssl/kalmanSSL.py."""
loss = 0.0
for t in range(zs_inf.size(1)):
Sigma_e = zs_pred_cov[:, t] + zs_inf_cov[:, t]
target = zs_inf[:, t].detach()
mean = zs_pred[:, t]
alignment = -gaussian_cross_entropy(mean, target, Sigma_e).mean()
loss = loss - alignment
return lossCode reference:
main@01d2ae83(2026-05-06) — pseudocode and mapping based on this commit
| Paper Concept | Source File | Key Class/Function |
|---|---|---|
| Gaussian LDM image representation | image-representations/solo/losses/gaussprob.py | gaussprob_loss_func_*, gauss_similarity, kozachenko_leonenko_* |
| Spherical / vMF LDM image representation | image-representations/solo/losses/sphereprob.py | sphereprob_loss_func_*, sphere_similarity, log_det_expansion |
| Training wrapper for Gaussian LDM | image-representations/solo/methods/gaussprob.py | GaussProb.training_step |
| Training wrapper for spherical LDM | image-representations/solo/methods/sphereprob.py | SphereProb.training_step |
| Kalman predictive SSL encoder/predictor | predictive-models/probssl/methods/kalmanSSL.py | KalmanSSLEncoder.forward, KalmanSSL.training_step |
| Kalman alignment loss | predictive-models/probssl/losses/ssl/kalmanSSL.py | kalmanSSL_loss_func |
| Synthetic moving-dot datasets | predictive-models/probssl/data/dot_motion.py | DotMotion, trajectory generators |
4. Experimental Setup (实验设置)
- Datasets used and scale:
- CIFAR-10、CIFAR-100、SVHN、Imagenet-100:用于 SSL pretraining + linear probing;论文未在正文列出样本数,采用 standard solo-learn setup。
- Synthetic moving-dot videos:800 training sequences,10×10 pixel resolution,20 epochs。
- Rat hippocampal hc-11 dataset:120 neurons,25 ms spike-count bins,约 45 min spike train;每 epoch 采样 100000 windows,每个 window 约 10 s,训练 16 epochs。
- Nonlinear system identification:10000 sequences,300 epochs;高维补充实验使用 100000 samples,20 epochs。
- Baselines / variants:
- LDM variants: Plane/Sphere latent space × Contrastive/KNN/LogDet entropy × or 。
- Named equivalents: VICReg = Plane-LogDet-;SimCLR = Sphere-Contr.-。
- Predictive variants: Kalman SSL with kNN entropy, stopgrad entropy, direct spike-to-position prediction, nonlinear predictor identification variants。
- Evaluation metrics:
- Linear probing Top-1 / Top-5 validation accuracy:冻结 representation 后训练 linear classifier。
- :linear / MLP decoder 对 true position 或 latent state 的解释方差。
- Gradient cosine similarity:比较不同 entropy objective 对网络权重梯度方向是否一致。
- Jacobian rank / eigenspectrum / UMAP:分析 representation dimensionality 与 class geometry。
- Training config:
- Image representation: solo-learn extended implementation,ResNet-18 encoder;CIFAR 1000 epochs,Imagenet-100 400 epochs;三次 runs 计算 standard deviation。
- kNN entropy: Euclidean metric ,,丢弃 upper 10% kNN outliers。
- Kalman synthetic: MLP encoder,1 hidden layer,100 units;diagonal ;Kalman filter faster timescale than encoder。
- hc-11: 8D observation,16D latent state;position uncertainty 通过从 采样 1000 次估计 95% CI。
- Nonlinear identification: encoder MLP 10 hidden layers × 200 units;predictor = 1-layer LSTM with 10 hidden units + 2-layer MLP head with 20 hidden units。
5. Experimental Results (实验结果)
5.1 Main linear probing results
| Space | Entropy | MI | CIFAR-10 Top-1 | CIFAR-100 Top-1 | Imagenet-100 Top-1 |
|---|---|---|---|---|---|
| Plane | Contr. | 91.9±0.1 | 65.3±0.3 | 72.6 | |
| Plane | Contr. | 92.1±0.1 | 65.3±0.2 | 72.7 | |
| Plane | kNN | 92.1±0.2 | 65.6±0.4 | 74.3 | |
| Plane | kNN | 91.9±0.1 | 65.8±0.4 | 73.7 | |
| Plane | LogDet | 92.1±0.1 | 69.5±0.2 | 75.9 | |
| Plane | LogDet | 91.9±0.1 | 68.6±0.1 | 74.7 | |
| Sphere | Contr. | 90.9±0.1 | 64.8±0.2 | 72.1 | |
| Sphere | Contr. | 91.4±0.2 | 66.0±0.2 | 73.1 | |
| Sphere | kNN | 90.2±0.0 | 64.3±0.1 | 72.6 | |
| Sphere | kNN | 90.0±0.2 | 64.5±0.2 | 73.3 | |
| Sphere | LogDet | 91.4±0.2 | 65.4±0.2 | 73.0 | |
| Sphere | LogDet | 91.2±0.1 | 65.4±0.2 | 72.3 |
主要结论:with / without MI maximization 没有系统性差异;影响更大的是 latent space 与 entropy estimator。Plane-LogDet- 在 CIFAR-100 和 Imagenet-100 上最好,说明直接 LDM entropy matching 不必依赖 MI term。
5.2 Predictive SSL results
- Synthetic Kalman task:CPC-like MI maximization 与 predictive stopgrad SSL 都恢复 true latent variables,linear probing 。
- hc-11 hippocampal spikes:Kalman latent state 排列成 position/direction circle,并能输出位置不确定性;补充图显示 latent MLP decoder ,latent linear decoder ,直接 spike MLP ,直接 spike linear 。
- Nonlinear system identification:predictive LDM 训练后 recovered latent space 可线性恢复 true position,。
- Square movement 补充实验:contrastive 和 stopgrad predictive SSL 均找到好 encoding,;input-dependent observation noise 补充实验中 hidden variables 轨迹估计 。
5.3 Ablations / key findings
- MI maximization 的边际贡献小:论文从理论上指出 MI 对 invertible transforms 不敏感;实验上 与 的 linear probing、eigenspectrum、Jacobian rank、gradient direction 都相近。
- Entropy estimator 更关键:KDE 对应 contrastive,LogDet 对应 non-contrastive covariance regularization,kNN 给出 sampling-free / contrastive-adjacent 估计;这些选择明显改变性能和 geometry。
- stopgrad 的角色被重新解释:它近似条件熵估计,而不是单纯 heuristic collapse avoidance。
- Identifiability 来自 predictive residual distribution:Gaussian prediction error 排除了任意 nonlinear reparameterization,使 representation 只能差一个 affine transform。
5.4 Limitations
作者明确提到:BYOL/SimSiam derivation 使用 empirical prior,future work 可以探索 proper prior,以减少额外 model regularization 的需求。另一个限制是 identifiability 结论依赖 encoder 在 data manifold 上可逆、predictor coverage、noise / predictive distribution assumptions;在 nonlinear scenario 中,stopgrad estimator 只得到 approximate identification。论文的实验主要是 controlled image / synthetic dynamics / spike-train setup;我推断这意味着其大规模 foundation-model 泛化仍需额外验证。
总体结论:LDM 给 SSL 提供了比 MI maximization 更直接的统一解释;它能解释已有方法、设计新 predictive SSL、并给出 affine-identifiability 保证。对实践者来说,关键设计旋钮是 latent model 与 entropy estimator,而不是盲目追求 MI objective。