WorldCompass: Reinforcement Learning for Long-Horizon World Models

Info

论文链接:arXiv:2602.09022
项目主页:Tencent Hunyuan World
代码仓库:Tencent-Hunyuan/HY-WorldPlay

Warning

更新:远程仓库 HY-WorldPlayworldcompass/ 目录现已可访问,我通过 GitHub tree API 与原始源码进一步核对了 worldcompass/fastvideo/training/world_compass_in_train_pipeline.pyworldcompass/fastvideo/training/training_utils.pyworldcompass/reward_function/reward_function.pyworldcompass/scripts/train_worldcompass.sh。因此下文中的训练伪代码、reward 伪代码与代码映射表,已经按 worldcompass/ 下的真实实现重新校正;但这些伪代码仍然是为了阅读性做过抽象的高层概括,并非逐行转写。对于与源码存在实现细节偏差、评估调用频率差异或代码瑕疵的部分,文中会明确说明“按代码意图概括”或“源码当前存在覆盖/简化问题”,避免误导为逐行等价实现。对于 backbone/action memory 等底层机制,文中仍保留与 WorldPlay 主干代码的对应说明。

1. Motivation(研究动机)

现有 video-based world model(如 WorldPlay、Genie 系列、Cosmos 系列)已经证明:如果把 world modeling 视作带 action condition 的 autoregressive video generation,就可以让模型在“观察历史状态 + 执行动作”的交互回路中持续探索环境。但这些方法的主要优化重心仍然放在 pre-training:模型通过视频重建或下一段预测的 pixel/token supervision,隐式学习 action following 与 scene consistency。

论文指出,这种纯 pre-training 范式存在三个核心瓶颈:

  • 动作反馈过于间接:训练目标是重建下一个 clip,而不是显式奖励“是否正确执行了动作”。因此模型在动作切换、组合动作、长时域连续操控下往往会迟钝或跑偏。
  • 长时域误差积累严重:world model 本质是 autoregressive 生成,越往后生成,越容易出现视觉漂移、几何不一致和动作偏移;如果没有额外监督,模型很难在 rollout 过程中自我纠正。
  • 现有 diffusion RL 方案不适配 world model:直接套用 sequence-level reward 会导致反馈极度稀疏;而一些 diffusion RL 方法依赖同噪声下的 SDE exploration,论文认为这更容易改变“画面细节”,却难以充分探索“相机轨迹/动作执行”的差异。

因此,本文希望解决的问题是:如何为 long-horizon、autoregressive、interactive 的 world model 设计一个真正可用的 RL post-training 框架,使模型既更听动作,又不牺牲视觉质量。 这件事值得研究,因为它直接决定 world model 是否能从“能生成视频”迈向“能稳定执行长期交互任务”,同时也关系到 future agent、游戏引擎、可交互 3D content generation 等更广泛应用。

2. Idea(核心思想)

WorldCompass 的核心洞见是:world model 的 RL 不能沿用 sequence-level 的粗粒度奖励,而应当围绕 autoregressive video generation 的“当前目标 clip”建立 clip-level rollout、双奖励评价与高效负样本感知优化。

更具体地说,论文把世界模型后训练重构为三件事:

  1. 在固定历史 prefix 的前提下,只对第 个目标 clip 做一组 rollout,从而得到可比较、细粒度的 reward;
  2. 同时使用 Interaction FollowingVisual Quality 两个 reward,防止 reward hacking;
  3. 用受 DiffusionNFT 启发的 negative-aware fine-tuning 来优化 diffusion world model,并辅以 Best-of-N、timestep 子采样、progressive horizon 等效率设计。

它与现有方法的根本差别在于:不是把整个长视频作为单个样本打分,而是把 RL 对齐聚焦到“当前 clip 是否在固定上下文下正确响应动作”。这让 reward 更密、更稳定,也更符合 autoregressive inference 的真实使用方式。

3. Method(方法)

3.1 Overall framework

Figure 1 解读:这张总览图给出了 WorldCompass 的完整训练闭环。最左侧是环境条件 与动作序列 ;系统先生成共享的 prefix clips,然后只在第 个目标 clip 上做 次 rollout,得到一组候选视频片段。中间模块分别对这些候选样本计算 action-following 与 visual-quality reward;最右侧再利用 reward 归一化后的 optimality probability 去驱动 RL 优化。图中最关键的设计是“共享 prefix + 当前 clip 多样化 rollout”,它把 world model 的 long-horizon 生成问题压缩为一个局部、可比较、可优化的 clip-level decision problem。

论文把 interactive world modeling 建模为:

其中 是第 个未来 clip, 是历史观测, 是当前动作条件, 是文本或图像 prompt。与 LLM 的 sequence-level rollout 不同,WorldCompass 主张只在目标 clip 做对比采样:

这样做有两点直接收益:

  • 计算量从近似 降到 ,因为 prefix 只生成一次;
  • 不同样本共享同一历史上下文,reward 可比性更强,更适合作 advantage normalization。

3.2 Clip-level rollout 与 progressive horizon

论文在 Algorithm 1 中把训练步骤写得很清楚:每个 iteration 选择一个目标 clip index

也就是随着训练步数推进,优化目标会循环覆盖从短 horizon 到长 horizon 的不同 clip 位置。作者认为这是一种自然的 curriculum learning:模型先学会短时响应,再逐步处理更长历史依赖。

结合最新公开的 worldcompass/ 源码,可以看到论文中的 RL 后训练已经有可对应的实现入口,而底层 rollout 与 memory reconstruction 仍然复用 WorldPlay/fastvideo 体系:

  • worldcompass/fastvideo/training/world_compass_in_train_pipeline.py 中的 train()train_one_step()_sample_reference_model()_prepare_grpo_inputs()_nft_forward_and_compute_loss() 构成了真实的 RL 后训练主流程;
  • worldcompass/fastvideo/training/training_utils.py 中的 EMA_FSDP_schedule 维护了 policy_shadowckpt_shadow,其中 policy_shadow 实际扮演 old/reference policy;
  • worldcompass/reward_function/reward_function.py 中的 CompassReward.reward()CompassReward.score_video() 实现了动作精度、HPSv3 文本分、HPSv3 质量分与画质漂移分的计算;
  • hyvideo/pipelines/worldplay_video_pipeline.py 中的 ar_rollout(...) 会先缓存文本条件,再在每个 chunk 生成前重建 memory context;
  • trainer/dataset/ar_camera_hunyuan_w_mem_dataset.py 中的 select_aligned_memory_frames(...) 会根据历史相机轨迹和 FOV overlap 选择最相关的 memory frames;
  • trainer/models/hyvideo/models/transformers/ar_action_hunyuanvideo_1_5_transformer.pyhyvideo/models/transformers/worldplay_1_5_transformer.py 都展示了 action 是如何被显式注入 transformer condition 的。

下面给出基于公开代码整理出的关键实现伪代码。前两段对应底层 world model memory / rollout,后面几段对应最新可访问的 worldcompass/ RL 训练实现。

组件 A:memory frame selection(代码可验证)

def select_aligned_memory_frames(w2c_list, current_frame_idx,
                                 memory_frames=20,
                                 temporal_context_size=12,
                                 pred_latent_size=4):
    if current_frame_idx <= memory_frames:
        return list(range(current_frame_idx))
 
    context_frames = list(range(max(0, current_frame_idx - temporal_context_size),
                                current_frame_idx))
    query_clip = list(range(current_frame_idx,
                            min(current_frame_idx + pred_latent_size, len(w2c_list))))
 
    candidate_scores = []
    for hist_idx in range(4, current_frame_idx - temporal_context_size, 4):
        dist = 0.0
        for query_idx in query_clip:
            dist_1 = 1.0 - fov_overlap(w2c_list[query_idx], w2c_list[hist_idx])
            dist_2 = 1.0 - fov_overlap(w2c_list[query_idx], w2c_list[hist_idx + 2])
            dist += 0.5 * (dist_1 + dist_2)
        candidate_scores.append((hist_idx, dist / len(query_clip)))
 
    candidate_scores.sort(key=lambda x: x[1])
 
    memory = [0, 1, 2, 3]
    max_memory = memory_frames - temporal_context_size
    for start_idx, _ in candidate_scores:
        if len(memory) >= max_memory:
            break
        if start_idx not in memory:
            memory.extend(range(start_idx, start_idx + 4))
 
    return sorted(set(context_frames + memory))

这个函数对应 trainer/dataset/ar_camera_hunyuan_w_mem_dataset.py,说明 WorldPlay/WorldCompass 体系并不是简单截断历史,而是会优先回收几何上最相关的历史片段,从而减轻长时域 memory attenuation。

组件 B:autoregressive chunk rollout(代码可验证)

def ar_rollout(transformer, scheduler, latents, prompt_embeds, prompt_mask,
               vision_states, cond_latents, viewmats, Ks, action, num_chunks):
    kv_cache = cache_text_condition(transformer, prompt_embeds, prompt_mask, vision_states)
 
    for chunk_i in range(num_chunks):
        if chunk_i > 0:
            selected = build_memory_context(viewmats, current_chunk=chunk_i)
            context_latents = latents[:, :, selected]
            context_cond = cond_latents[:, :, selected]
            kv_cache = cache_vision_context(
                transformer,
                hidden_states=torch.cat([context_latents, context_cond], dim=1),
                viewmats=viewmats[:, selected],
                Ks=Ks[:, selected],
                action=action[:, selected],
                kv_cache=kv_cache,
            )
 
        start, end = chunk_i * 4, chunk_i * 4 + 4
        for t in scheduler.timesteps:
            latent_now = latents[:, :, start:end]
            cond_now = cond_latents[:, :, start:end]
            model_input = scheduler.scale_model_input(
                torch.cat([latent_now, cond_now], dim=1), t
            )
            noise_pred = transformer(
                hidden_states=model_input,
                timestep=torch.full((4,), t, device=latent_now.device),
                viewmats=viewmats[:, start:end],
                Ks=Ks[:, start:end],
                action=action[:, start:end],
                kv_cache=kv_cache,
                cache_vision=False,
            )[0]
            latents[:, :, start:end] = scheduler.step(noise_pred, t, latent_now)[0][:, :, -4:]
 
    return latents

这段逻辑对应 hyvideo/pipelines/worldplay_video_pipeline.py。它体现了 world model 的真正推理模式:先缓存文本,再逐 chunk 重建视觉 memory,并在每个 chunk 内执行 diffusion denoising。WorldCompass 的 clip-level RL 正是建立在这一 autoregressive rollout 机制之上。

3.3 Reward functions

WorldCompass 只奖励两个维度,但它们覆盖了交互式世界模型最核心的两个目标。

(1) Interaction Following Score

论文把动作分成 translationrotation 两部分。做法不是直接读取控制信号,而是先用 3D foundation model(文中引用 WorldMirrorDepth Anything V3)从生成视频中恢复 camera trajectory,再把连续轨迹映射回离散动作空间。

  • 对 rotation:比较相邻帧相对旋转是否超过阈值
  • 对 translation:由于不同场景下深度尺度不稳定,单一阈值不鲁棒,因此使用多组 ,只要任一阈值下匹配成功就算正确;
  • 最终将 translation accuracy 与 rotation accuracy 平均,作为 interaction following score。

这意味着 reward 不是“视频看起来像在动”这么模糊,而是“恢复出的 camera motion 是否与输入动作一致”,它更贴近 world model 真正需要对齐的 controllability。

def reward(video_path, gt_action, caption, camera_estimator="dav3",
           interval=1, update_latent_num=4):
    first_frame, images, predictions = process_video(
        video_path,
        interval=interval,
        last_frames=(update_latent_num * 4) - 3,
        camera_estimator=camera_estimator,
    )
 
    pred_pose = to_4x4_camera_pose(predictions.extrinsics)
    gt_action = expand_gt_action_to_frame_labels(gt_action)
 
    pred_002 = camera_pose_to_discrete_action(pred_pose, move_threshold=0.002, rotate_threshold=0.2)
    pred_005 = camera_pose_to_discrete_action(pred_pose, move_threshold=0.005, rotate_threshold=0.2)
    pred_010 = camera_pose_to_discrete_action(pred_pose, move_threshold=0.010, rotate_threshold=0.2)
 
    candidate_preds = [pred_002, pred_005, pred_010]
    candidate_accs = [
        (pred == gt_action).float().mean()
        for pred in candidate_preds
    ]
    best_idx = torch.tensor(candidate_accs).argmax().item()
    best_pred = candidate_preds[best_idx]
    action_acc = candidate_accs[best_idx]
 
    trans_acc = ((best_pred // 9) == (gt_action // 9)).float().mean()
    rotate_acc = ((best_pred % 9) == (gt_action % 9)).float().mean()
    fine_action_acc = 0.5 * (trans_acc + rotate_acc)
 
    hps_images = images[3::4]
    hpsv3_acc = hpsv3_text_reward(hps_images, caption).mean()
    hpsv3_quality_acc = hpsv3_quality_reward(hps_images).mean()
    first_quality = hpsv3_quality_reward([first_frame])[0]
    hpsv3_quality_drift_score = (-torch.abs(hpsv3_quality_reward(hps_images) - first_quality)).mean()
 
    return {
        "action_acc": action_acc,
        "fine_action_acc": fine_action_acc,
        "hpsv3_acc": hpsv3_acc,
        "hpsv3_quality_acc": hpsv3_quality_acc,
        "hpsv3_quality_drift_score": hpsv3_quality_drift_score,
    }

(2) Visual Quality Score

视觉质量部分使用 HPSv3。作者每隔 4 帧采样一张图像,计算这些帧的平均 HPSv3 分数。由于 HPSv3 同时对 text-image alignment 和 aesthetic quality 敏感,因此它在这里只是一个 proxy,但足够作为“别为了动作正确而把画面训崩”的 regularizer。

论文特别强调:两个 reward 必须同时使用。单独优化 IF score 会导致画面质量下降甚至 collapse;单独优化 VQ score 则会生成漂亮但几乎不动的静态视频。

def score_video(video_path, gt_action, caption, latent_num, camera_estimator="dav3"):
    first_frame, images, predictions = process_video(
        video_path,
        interval=1,
        last_frames=(latent_num * 4) - 3,
        camera_estimator=camera_estimator,
    )
 
    pred_pose = to_4x4_camera_pose(predictions.extrinsics)
    # 每个 chunk 对应 4 帧动作标签;实现里会先展开到逐帧标签,再丢掉第 1 帧,
    # 这样每个 chunk 都用“最近 5 帧 pose -> 当前 4 帧 action”来打分,
    # 相邻 chunk 因而共享 1 帧边界 pose。
    gt_action_expand = gt_action.repeat_interleave(4, dim=1)[:, 4:]
    action_scores, hps_scores = [], []
    quality_scores, drift_scores = [], []
    first_quality = hpsv3_quality_reward([first_frame])[0]
 
    for chunk_idx in range(latent_num - 1):
        chunk_end = ((chunk_idx + 1) * 4) + 1
        recent_pose_window = pred_pose[:chunk_end][-5:]
        pred_002 = camera_pose_to_discrete_action(recent_pose_window, move_threshold=0.002, rotate_threshold=0.2)
        pred_005 = camera_pose_to_discrete_action(recent_pose_window, move_threshold=0.005, rotate_threshold=0.2)
        pred_010 = camera_pose_to_discrete_action(recent_pose_window, move_threshold=0.010, rotate_threshold=0.2)
        gt_start = chunk_idx * 4
        gt_end = gt_start + 4
        chunk_gt = gt_action_expand[:, gt_start:gt_end]
        chunk_action_acc = max(
            (pred_002 == chunk_gt).float().mean(),
            (pred_005 == chunk_gt).float().mean(),
            (pred_010 == chunk_gt).float().mean(),
        )
        action_scores.append(chunk_action_acc)
 
        frame = images[chunk_end - 1]
        hps_scores.append(hpsv3_text_reward([frame], caption)[0])
        q = hpsv3_quality_reward([frame])[0]
        quality_scores.append(q)
        drift_scores.append(-torch.abs(q - first_quality))
 
    return {
        "action_acc": action_scores,
        "hps_acc": hps_scores,
        "hps_quality_acc": quality_scores,
        "hps_drift_score": drift_scores[-4:],
    }

这里我按源码语义把 score_video() 解释成“最近 5 帧 pose 窗口映射到当前 chunk 的 4 帧动作”。也就是说,第 k 个 chunk 的动作评估会取截至该 chunk 末尾的 5 帧相机位姿,再分别用 0.002 / 0.005 / 0.010 三组平移阈值离散化,最后取最优精度;对应 GT 则来自展开后的逐帧动作标签切片 gt_action_expand[:, k*4:(k+1)*4]。这比前一版伪代码更接近 reward_function.py 的实际索引语义,但仍然是为可读性做过整理的高层表达。

3.4 RL objective:negative-aware fine-tuning

作者没有采用 FlowGRPO 的同噪声 SDE rollout,而是采用受 DiffusionNFT - Online Diffusion Reinforcement with Forward Process 启发的 negative-aware fine-tuning。核心步骤如下:

先对每个 reward 维度做组内标准化,得到 advantage:

再将两种 advantage 线性组合并裁剪,得到第 个 rollout 的 optimality probability:

最后构造 negative-aware loss:

其中

直觉上, 会靠近“更优样本”方向,而 则把模型推离“较差样本”方向; 越大,正向项权重越大。论文还特别说明:他们没有使用 DiffusionNFT 中的 KL regularization,而是依靠更小的 learning rate 与 old policy 的 EMA 更新抑制过优化。

组件 C:WorldCompass 训练主循环(对齐 world_compass_in_train_pipeline.py

def train(self):
    setup_rng_and_scheduler()
    maybe_resume_from_checkpoint()
    self.ema_generator = EMA_FSDP_schedule(
        self.transformer,
        min_decay=0.2,
        max_decay=0.9,
        step_decay=0.01,
        ckpt_decay=0.9,
    )
 
    for step in range(self.init_steps + 1, self.training_args.max_train_steps + 1):
        if step == 1:
            self._eval(step - 1)
            if self.training_args.eval_only:
                return
 
        # 源码文本里这里还会额外执行 `self._eval(step - 1)`;
        # 为了突出训练主干,这里省略那条频繁评估路径,只保留主优化流程。
        training_batch = TrainingBatch()
        training_batch.current_timestep = step
        training_batch.current_vsa_sparsity = compute_vsa_sparsity(step)
        training_batch = self.train_one_step(training_batch)
        log_to_wandb(training_batch)
 
        if step % self.training_args.checkpointing_steps == 0:
            self._eval(step - 1)
            with self.ema_generator.apply_ckpt_shadow_to_model(self.transformer):
                save_checkpoint(...)
 
 
def train_one_step(self, training_batch):
    training_batch = self._prepare_training(training_batch)
    training_batch = self._get_next_batch(training_batch)
    training_batch = self._sample_reference_model(training_batch)
    training_batch = self._prepare_grpo_inputs(training_batch)
    training_batch = self._nft_forward_and_compute_loss(training_batch)
    self._clear_cache(training_batch)
    return training_batch

组件 D:reference rollout + reward(对齐 _sample_reference_model

@torch.no_grad()
def _sample_reference_model(self, training_batch):
    with self.ema_generator.apply_policy_shadow_to_model(self.transformer):
        selected_chunk_id = choose_chunk_id(
            current_step=training_batch.current_timestep,
            min_chunk_id=self.training_args.min_chunk_id,
            max_chunk_id=self.training_args.max_chunk_id,
            strategy=self.training_args.chunk_selection_strategy,
        )
 
        noise = make_group_noise(shape=(1, c, t, h, w))
        context_video = None
        if selected_chunk_id > 1:
            context_latents = self._sample_model_ode(training_batch, selected_chunk_id, noise)
            context_video = decode_context_latents(context_latents)
 
        kv_cache = self._create_sample_kv_cache()
        kv_cache = cache_text_branch(kv_cache, training_batch)
 
        if selected_chunk_id > 1:
            kv_cache, selected_frame_indices = self._build_kv_cache_from_previous_chunks(
                training_batch, context_latents, kv_cache, selected_chunk_id
            )
 
        for sample_idx in range(self.training_args.grpo_generation_num):
            new_noise = sample_specific_noise(sample_idx)
            latents_curr = build_sample_latents(context_latents, new_noise, sample_idx)
            start_chunk_id = max(0, selected_chunk_id - 1)
 
            for chunk_i in range(start_chunk_id, selected_chunk_id):
                for sigma in sigma_schedule(self.training_args.sampling_steps):
                    input_dict = build_sampling_inputs(
                        training_batch,
                        latents_curr,
                        kv_cache,
                        chunk_i=chunk_i,
                        selected_chunk_id=selected_chunk_id,
                    )
                    model_pred = self.transformer(txt_branch=False, input_dict=input_dict)[0]
                    latents_curr, pred_original = self.flux_step(model_pred, sigma, latents_curr)
 
            video = decode_video(pred_original)
            video_path = save_temp_video(video, context_video)
            reward_info = self.reward_model.reward(
                video_path,
                gt_camera_pose=training_batch.w2c,
                gt_action=current_chunk_action_tail(training_batch.action, update_latent_num=self.training_args.single_chunk_size),
                caption=training_batch.prompt,
                interval=1,
                update_latent_num=self.training_args.single_chunk_size,
            )
            stash_sample_kwargs(training_batch, sample_idx, pred_original, reward_info)
 
    return training_batch

组件 E:GRPO 风格样本筛选与 NFT 损失(对齐 _prepare_grpo_inputs + _nft_forward_and_compute_loss

def _prepare_grpo_inputs(self, training_batch):
    rewards = collect_rewards_from_sample_kwargs(training_batch)
    advantages = normalize_rewards(
        rewards,
        std_type=self.training_args.std_type,
        action_reward_type=self.training_args.action_reward_type,
    )
 
    overall_reward = (
        self.training_args.action_reward_weight * advantages["action"]
        + self.training_args.hpsv3_reward_weight * advantages["hpsv3"]
        + self.training_args.hpsv3_quality_reward_weight * advantages["hpsv3_quality"]
        + self.training_args.hpsv3_quality_drift_reward_weight * advantages["hpsv3_drift"]
    )
 
    sorted_indices = torch.argsort(overall_reward)
    per_rank = (self.training_args.bestofn // 2) // self.gpu_group.world_size
    top_indices = sorted_indices[-per_rank:]
    bottom_indices = sorted_indices[:per_rank]
    selected_indices = torch.cat([top_indices, bottom_indices])
    training_batch.sample_kwargs["shuffled_sample_keys"] = shuffle(selected_indices)
    return training_batch
 
 
def _nft_forward_and_compute_loss(self, training_batch):
    for sample_idx, sample_key in enumerate(training_batch.sample_kwargs["shuffled_sample_keys"]):
        train_sigmas = random_subset_of_sigma_schedule(
            total_steps=self.training_args.sampling_steps,
            fraction=self.training_args.train_timestep_fraction,
        )
 
        for sigma in train_sigmas:
            noisy_latents = make_noisy_latents(training_batch.sample_kwargs[sample_key]["pred_latents"], sigma)
            input_dict = build_training_inputs(training_batch, sample_key, noisy_latents)
 
            with torch.no_grad():
                with self.ema_generator.apply_policy_shadow_to_model(self.transformer):
                    model_pred_old = self.transformer(txt_branch=False, input_dict=input_dict)[0]
 
            model_pred = self.transformer(txt_branch=False, input_dict=input_dict)[0]
            # 这里按“代码意图”表达 old policy / current policy 双路预测。
            # 当前源码文本里存在 `model_pred_old` 被后续赋值覆盖的可疑实现细节,
            # 因此此处伪代码不把那一覆盖行为当作算法本意来展示。
            negative_prediction = 2 * model_pred_old.detach() - model_pred
 
            advantage = clip_advantage(training_batch.sample_kwargs[sample_key], adv_clip_max=2.0)
            r = map_advantage_to_probability(advantage)
 
            positive_x0 = reconstruct_x0_from_model_prediction(model_pred, noisy_latents, sigma)
            negative_x0 = reconstruct_x0_from_model_prediction(negative_prediction, noisy_latents, sigma)
            x0_target = training_batch.sample_kwargs[sample_key]["pred_latents"]
            positive_loss = ((positive_x0 - x0_target) ** 2).mean()
            negative_loss = ((negative_x0 - x0_target) ** 2).mean()
            policy_loss = r * positive_loss + (1 - r) * negative_loss
            (policy_loss / self.training_args.gradient_accumulation_steps).backward()
 
        if (sample_idx + 1) % self.training_args.gradient_accumulation_steps == 0:
            clip_grad_norm_()
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.ema_generator.update_ckpt_shadow(self.transformer)
            self.ema_generator.update_policy_shadow(self.transformer, training_batch.current_timestep)

这几段伪代码是对 worldcompass/ 目录下真实训练实现的高层代码级抽象。需要注意四点:1) train() 段刻意省略了源码文本里更频繁的 _eval(step - 1) 调用,只保留主优化路径与 checkpoint 前评估;2) _sample_reference_model() 的真实循环层次是 sample_idx -> chunk_i -> sigma,而历史 chunk 常由 _sample_model_ode() 与 KV cache 预先处理;3) 代码里的 bestofn=6 实际是“每个 rank 取本地高分一半 + 低分一半,再拼接并打乱”,不是纯 top-N;4) old policy 并不是独立模型类,而是 EMA_FSDP_schedule.policy_shadow 通过上下文管理器临时覆盖到当前 transformer 上,且源码当前文本里还存在 model_pred_old 可能被覆盖的实现瑕疵,因此这里展示的是更接近算法意图的抽象流程。

3.5 Efficient training strategy

WorldCompass 之所以可以在 64 张 H20 上 3 天内完成训练,与下列三项效率设计直接相关:

  1. Subset of diffusion timesteps:只对一部分采样步做训练,而不是遍历所有
  2. Best-of-N selection:每组 rollout 只选 top-3 与 bottom-3 样本用于训练;
  3. Progressive optimization over clip index:目标 clip index 从 1 到 循环推进。

这三项设计并非附属技巧,而是 world model RL 可行性的关键。因为长时域 autoregressive video 的 rollout、reward 和 denoising 成本都远高于 text RL;如果没有 aggressive 的 sample/timestep pruning,训练成本会非常高。

Figure 2 解读:这张曲线图展示 RL 训练过程中 Interaction Following 与 Visual Quality reward 的演化趋势。作者在固定复杂组合动作测试子集上评估,可以看到两类 reward 都在较少训练步数内显著提升,说明 clip-level rollout + 双 reward + negative-aware fine-tuning 的组合确实给出了稳定而有效的优化信号。

Figure 3 解读:这组复杂组合动作的定性对比说明 WorldCompass 改善最明显的地方不是“画面更锐利”,而是面对连续复合动作时的轨迹执行更正确、动作切换更及时。图中的对比支持了论文的一个核心论断:复杂动作场景下,world model 的主要瓶颈是 controllability,而非单纯感知质量。

Figure 4 解读:在基础动作序列下,WorldCompass 仍然能提升 action following 与视觉质量,但提升幅度没有复杂动作那么夸张。这与论文主结果一致:基础动作本来就相对容易,RL 后训练的优势更多体现在动作切换速度和长期稳定性上。

Figure 5 解读:附录案例 1 用同一套 W+A -> right turn 动作序列展示训练前后的几何路径差异。论文通过重建 3D scene 和相机轨迹来对比,强调 WorldCompass 不仅让视频“看起来更像在动”,还让运动轨迹与空间几何关系更一致。

Figure 6 解读:案例 2 进一步说明改进并非局限于单一场景。对于不同几何结构和纹理环境,RL 后模型仍能保持更好的 trajectory adherence,说明 reward 是在优化通用的 action-conditioned world modeling 能力,而不是针对个别 prompt 过拟合。

Figure 7 解读:案例 3 更强调 long-horizon 情况下的误差积累问题。基线模型随着时间推进更容易出现 motion drift,而经过 WorldCompass 训练后,这种 drift 被明显压制,说明 clip-level reward 确实在长期 rollout 中提供了更有辨识度的纠偏信号。

Figure 8 解读:案例 4 汇总了附录可视化的共同结论——WorldCompass 对“动作执行准确性”和“空间几何一致性”两者都有改善,而不是只优化其中一项。这恰好呼应了双 reward mutual regularization 的方法设计。

3.6 Code-to-paper mapping table

论文概念Source FileKey Class / Function说明
WorldPlay autoregressive backbonehyvideo/models/transformers/worldplay_1_5_transformer.pyHunyuanVideo_1_5_DiffusionTransformer推理侧主干 transformer
Action conditioninghyvideo/models/transformers/worldplay_1_5_transformer.pyadd_action_parameters, forward_vision, forward_bi在向量条件中注入 action
AR action-aware training modeltrainer/models/hyvideo/models/transformers/ar_action_hunyuanvideo_1_5_transformer.pyARHunyuanVideo_1_5_DiffusionTransformer, add_discrete_action_parameters, forward训练侧 action-aware backbone
Reconstituted context memory(推理侧)hyvideo/utils/retrieval_context.pyselect_aligned_memory_frames推理链路实际调用的 memory frame 检索函数,基于 FOV overlap 选择历史帧
Reconstituted context memory(训练侧)trainer/dataset/ar_camera_hunyuan_w_mem_dataset.pyselect_aligned_memory_frames数据集侧使用的同名 memory 选择逻辑,用于构造训练样本
Autoregressive rollout with memory cachehyvideo/pipelines/worldplay_video_pipeline.pyar_rolloutar_rollout 在每个 chunk 前重建 context,并通过 kv_cache / cache_txt / cache_vision 复用缓存
Bidirectional rollout with memory concathyvideo/pipelines/worldplay_video_pipeline.pybi_rolloutbi_rollout 也重建 memory context,但主要通过拼接 context latents 前向,并不走 ar_rollout 的 KV cache 流程
Pose to discrete action labelhyvideo/generate.pypose_to_input将轨迹/姿态转换为 action_one_label 离散标签
Pose string to frame timelinehyvideo/generate.pyparse_pose_string_to_actions把动作字符串解析为逐帧 action timeline,而非最终离散标签
Memory-aware AR training looptrainer/training/ar_hunyuan_mem_training_pipeline.py_prepare_ar_dit_inputs, _build_input_kwargs, _transformer_forward_and_compute_loss, train开源训练主循环
WorldCompass RL trainerworldcompass/fastvideo/training/world_compass_in_train_pipeline.pytrain, train_one_step, _sample_reference_model, _prepare_grpo_inputs, _nft_forward_and_compute_lossRL 后训练主流程:reference rollout、reward 汇总、best-of-N 样本筛选、NFT 损失优化
Old/reference policy EMAworldcompass/fastvideo/training/training_utils.pyEMA_FSDP_schedule, apply_policy_shadow_to_model, update_policy_shadowold policy 通过 policy_shadow 维护,ckpt_shadow 则用于评测/保存 checkpoint
WorldCompass rewardworldcompass/reward_function/reward_function.pyCompassReward.reward, CompassReward.score_video, process_video奖励由动作精度、细粒度动作精度、HPSv3 文本分、HPSv3 质量分、画质漂移分构成
WorldCompass launch configworldcompass/scripts/train_worldcompass.shshell args公开脚本默认值为 grpo_generation_num=12bestofn=6sampling_steps=40camera_estimator=dav3

4. Experimental Setup(实验设置)

4.1 Datasets 与任务设定

  • Base world modelsWorldPlay 的两个版本:HunyuanVideo-1.5-8BWan2.2-5B
  • 动作空间:8 个 basic actions,包括前/后/左/右平移,以及上/下/左/右旋转;这些 basic actions 还能组合成复杂 compositional actions。
  • 视频组织方式:每个 chunk 对应一个 16-frame video clip;训练中最大生成长度 clips,对应约 256 decoded frames。
  • 训练数据:4,000 张多样化图像及其 caption;动作序列由 basic actions 随机构造而成,并刻意加入 complex action sequence 以增加难度。
  • 测试协议:从 WorldPlay test set 中选取 600 个 case,分别在 basic action 与 combined action 两种控制设置上评测,并报告短/中/长三种生成长度。

4.2 Metrics

  • Acc:每 4 帧与对应 action condition 对齐计算 action-following accuracy;
  • HPSv3:每 4 帧采样一次后求平均,作为 visual quality proxy。

这个评估方式相当严格,因为它不是只看整段视频最终是否“大致正确”,而是逐段检查每个动作区间是否真的被执行。

4.3 Training config

论文给出的关键超参数如下:

项目设置
每步训练 group 数64
rollout group size 16
rollout sampling steps 40
timestep 子采样比例50%
旋转阈值
平移阈值集合
optimizerMuon
learning rate
old policy EMA factor从 0.4 线性退火到 0.8
硬件64 × H20 GPUs
训练时长3 天

此外,论文还说明其 reward 完全来自自动评估器,因此不需要额外人工标注,这是 post-training 成本可控的重要原因。

5. Experimental Results(实验结果)

5.1 Main benchmark results

论文的主结果见 Table 1。为便于复查,我把原文数字按场景完整抄录如下。

HorizonModelCombined AccCombined HPSv3Basic AccBasic HPSv3
Short-term (125 frames)HY-Video-1.521.74-1.0562.331.96
Short-term (125 frames)+ WorldCompass58.200.4268.623.77
Short-term (125 frames)Wan2.222.87-1.1058.281.83
Short-term (125 frames)+ WorldCompass49.810.2064.723.17
Mid-term (253 frames)HY-Video-1.519.73-0.1963.351.91
Mid-term (253 frames)+ WorldCompass55.010.3774.093.61
Mid-term (253 frames)Wan2.220.33-1.6757.941.63
Mid-term (253 frames)+ WorldCompass50.320.2763.873.37
Long-term (381 frames)HY-Video-1.519.70-0.3364.281.90
Long-term (381 frames)+ WorldCompass54.820.7376.563.72
Long-term (381 frames)Wan2.219.58-0.8055.591.91
Long-term (381 frames)+ WorldCompass42.920.5963.913.59

几个最关键的观察:

  • 复杂组合动作提升最大:HY-Video-1.5 在 long-term combined action 上从 19.70 提升到 54.82,几乎是从“基本不会做”到“多数情况下能做对”。
  • 基础动作也稳定提升:HY-Video-1.5 在 long-term basic action 上从 64.28 提升到 76.56;Wan2.2 也从 55.59 提升到 63.91
  • 视觉质量没有被 RL 牺牲,反而同步上升:所有主要设置下 HPSv3 都明显提升。

论文特别强调:10%~30% 的 action accuracy 往往意味着模型根本没理解动作,而 50%~60% 则更多是动作切换 latency 造成的错误。因此从 20%55% 的跃迁并不是小修小补,而是行为能力层面的变化。

5.2 Core ablation

Table 2 检验了方法中的关键设计:

RowRollout TypeIF ScoreVQ ScoreRL AlgorithmCombined AccCombined HPSv3Basic AccBasic HPSv3
0----19.70-0.3364.281.90
1clip-levelDiffusionNFT54.820.7376.563.72
2sample-levelDiffusionNFT12.450.1958.422.69
3clip-levelDiffusionNFT36.39-2.6767.60-1.83
4clip-levelDiffusionNFT11.511.0135.944.19
5clip-levelDanceGRPO20.020.5967.433.97

这些消融几乎逐条验证了论文的核心 claim:

  • clip-level rollout 是必要的:row 2 的 sample-level rollout 直接把 combined action accuracy 拉低到 12.45
  • 双 reward 缺一不可:只保留 IF score 会把 HPSv3 打到负值,只保留 VQ score 会导致动作几乎不执行;
  • DiffusionNFT-style negative-aware optimization 更适合:换成 DanceGRPO 后,combined action accuracy 只有 20.02,几乎回到基线水平。

5.3 Efficiency ablation

Table 3 评估效率设计:

Subset of TimestepBest-of-N SamplingAccHPSv3Iteration Time
54.820.731.00×
55.280.751.42×
54.680.782.26×

可以看到:子采样 timestep 与 Best-of-N 并没有显著伤害最终性能,却把 iteration time 从 2.26× 压到 1.00×。也就是说,这些“效率优化”在这里几乎是免费午餐。

5.4 论文给出的局限性

作者在附录 Limitation 中承认,目前仍缺乏可靠指标去直接评估:

  • long-form generation 里的 visual quality drift
  • 长时域 rollout 中的 spatial memory retention

因此现有 reward 不能直接惩罚长视频后段的视觉漂移和空间记忆退化。作者现在只能通过更保守的训练策略(更小学习率、更少 iteration)缓解这一问题。换句话说,WorldCompass 已经解决“动作跟随”这一主要矛盾,但对“长期视觉/空间稳定性”的 reward 设计仍然不充分。

5.5 Overall conclusions

总体来看,WorldCompass 的贡献不在于提出一个全新 backbone,而在于证明:world model 需要独立的 RL post-training recipe,而这个 recipe 必须围绕 autoregressive rollout、clip-level reward 和多目标约束来设计。

从实验结果看,它最显著地提升了复杂组合动作场景下的 controllability,同时还带来了视觉质量提升;从方法论上,它把 diffusion RL 从 image/video generation 进一步推进到更贴近 agent interaction 的 world modeling 场景。对于后续研究,一个自然的问题是:如果未来能设计针对 long-horizon geometry drift 与 memory collapse 的更强 reward,world model 的 RL 上限可能会更高。