ThinkGen: Generalized Thinking for Visual Generation

Authors: Siyu Jiao, Yiheng Lin, Yujie Zhong, Qi She, Wei Zhou, Xiaohan Lan, Zilong Huang, Fei Yu, Yingchen Yu, Yunqing Zhao, Yao Zhao, Yunchao Wei Affiliations: Beijing Jiaotong University, ByteDance arXiv: 2512.23568 Project Page: huggingface.co/JSYuuu/ThinkGen GitHub: jiaosiyuu/ThinkGen

1. Motivation(研究动机)

现有视觉生成中的 CoT(Chain-of-Thought)方法大多只在单一场景内有效,例如只服务于 reasoning generation,或者依赖手工设计的“分步骤生成/指令改写”机制;一旦迁移到 text rendering、image editing、reflection 等任务,往往要重新定制流程,泛化性很差。

本文想解决两个核心问题:

  1. 如何把 MLLM 的显式推理能力真正迁移到视觉生成里,而不是只在某一个 benchmark 上临时奏效;
  2. 如何稳定地把“长而冗余”的 CoT 表征交给 DiT 使用,避免推理文本噪声拖累生成质量。

作者进一步指出,若直接把“思考”和“画图”绑成一个端到端大模型,会带来训练复杂、奖励设计困难、GPU 显存占用高等问题。因此,ThinkGen 试图构建一个可泛化、可分治、可强化学习优化的视觉生成框架,在 text-to-image、text rendering、image editing、reasoning generation、reasoning editing、reflection 等多场景上统一工作。


2. Idea(核心思想)

ThinkGen 的核心思想是:把“思考”与“绘制”解耦。具体来说,先用一个 pretrained MLLM(Qwen3-VL-8B-Think)做 CoT reasoning,把用户意图重写成更适合图像生成的 instruction;再用一个 DiT(初始化自 OmniGen2-DiT-4B)根据这些 instruction 生成图像。

论文的关键创新有三点:

  1. VGI-refine(Visual Generation Instruction Refinement):只保留 </think> 之后真正用于生成的 instruction hidden states,并拼接 learnable Prepadding States,减轻冗余 CoT token 对 DiT 的干扰;
  2. SepGRPO:把 RL 分成两段,先固定 DiT 优化 MLLM,再固定 MLLM 优化 DiT;
  3. 多场景联合训练:把 semantic composition、reasoning generation、text rendering、image editing、reflection 放在统一框架下训练,让“会思考的生成”从特例变成通用能力。

相对已有方法,ThinkGen 的根本差异不在于“多加一个 prompt rewrite 模块”,而在于:它把 MLLM 的 CoT 作为统一的中间推理层,并用 separable RL 让 reasoning module 与 generation module 各自朝最合适的方向优化。


3. Method(方法)

3.1 整体框架

Figure 3 解读:这张图给出了 ThinkGen 的主结构。左侧是 MLLM,负责接收文本或参考图像并进行 autoregressive CoT reasoning;中间是 VGI-refine,把 </think> 后的 instruction hidden states 抽取出来,再与 learnable Prepadding States 拼接;右侧是 DiT,把 refined text condition、噪声 latent、以及可选的参考图像 latent 一起送入 joint attention 生成最终图像。图中最重要的信息是:作者没有把“思考”和“绘制”做成一个单体网络,而是明确拆成 reasoning 模块与 generation 模块,这也是后续 SepGRPO 能成立的前提。

ThinkGen 的前向流程可以概括为四步:

  1. MLLM 理解用户意图:输入可以是 caption、editing instruction,或者“参考图 + 文本要求”;
  2. 生成 CoT 与 rewrite instruction:MLLM 输出带 <think> ... </think> 的推理过程与最终 instruction;
  3. VGI-refine 过滤 CoT:截取 </think> 之后的 token hidden states,并拼接 个 learnable Prepadding States;
  4. DiT 生成图像:DiT 接收 refined text states,以及可选的参考图像 latent,执行 flow matching / diffusion denoising 得到图像。

作者使用的骨干分别是:

  • MLLM:Qwen3-VL-8B-Think
  • DiT:OmniGen2-DiT-4B
  • VAE:用于把参考图编码成 visual latent

3.2 MLLM:从 CoT 到生成指令

论文在 supervised pre-training 与 SepGRPO 两个阶段使用不同输入模板:

  • Stage 1–3(pseudo-CoT)
  • Stage 4–5(真实 CoT rollout)

这里 是统一 system prompt, 是 caption 或 editing instruction。前者用于在没有真实 reasoning annotation 的大规模生成数据上“伪造”一个空 think 段,从而让 DiT 学会消费 reasoning-style condition;后者用于 RL 阶段让 MLLM 真正探索对 DiT 更友好的 rewrite instruction。

开源代码中,这个过程对应 ThinkGen/pipelines/pipeline_thinkgen.py 里的 _apply_chat_template()_get_qwen3_prompt_embeds()

  • think=True 时,会先让 MLLM 生成完整 CoT;
  • 然后取 最后两层 hidden states 拼接成条件表示;
  • 再根据特殊 token 位置,裁掉 </think> 之前的部分,只把 instruction hidden states 留给 DiT。

3.3 VGI-refine:去冗余 + Prepadding States

论文指出,MLLM 的 CoT token 很长、很冗余,若把整段 hidden states 直接喂给 DiT,会显著拖累生成。于是作者提出 VGI-refine,两步完成:

  1. CUT:只保留 </think> 之后的 hidden states;
  2. Prepadding:在前面拼接 个 learnable Prepadding States,帮助调整 representation distribution,尤其改善 short prompt 场景。

这一点在论文 ablation 中非常关键:

  • 只用 CUT(而不是 ALL)时,GenEval / WISE / CVTG / ImgEdit 都更好;
  • 加入 Prepadding States 后,short-prompt benchmark 明显提升。

从公开实现看,VGI-refine 对应两部分代码:

  • find_next_token_index():找到 </think> 之后的切分点;
  • ThinkGenTransformer2DModel.forward():将 prepad_embed 与文本 hidden states 拼接。

需要注意一个代码-论文差异:论文补充材料写的是 ,但公开推理代码中的 prepad_embed 形状为 1 x 23 x 8192。这说明当前公开仓库与论文描述之间存在一个小的不一致,读代码时应以 checkpoint/实现为准。

3.4 DiT:文本条件、参考图条件与联合注意力

在 DiT 侧,ThinkGen 把三类信息统一进 transformer:

  • noisy latent(当前待去噪图像)
  • refined text condition(来自 MLLM)
  • optional reference image latents(image editing / in-context generation 时)

论文强调,作者最后选用的是简单 linear connector,而不是 MLP 或更复杂的 transformer connector。补充材料进一步说明,该 connector 将 Qwen3-VL-8B-Think 最后两层 hidden states 拼接后的 8192 维特征映射到 2520 维,以匹配 DiT 的条件输入需求;补充实验也显示,Linear 在 Stage1 上优于 MLP / Transformer connector。

公开代码中,DiT 的实现位于 ThinkGen/models/transformers/transformer_thinkgen.py,关键机制包括:

  • x_embedder:把 noisy latent patch 投影到 transformer hidden space;
  • ref_image_patch_embedder:把参考图 latent patch 投影到相同空间;
  • image_index_embedding:给不同参考图加上区分性 embedding;
  • context_refiner / noise_refiner / ref_image_refiner:分别细化 text / noise / reference-image token;
  • joint hidden states 拼接后进入主干 transformer blocks。

3.5 训练配方:3 个 supervised stage + 2 个 RL stage

Figure 4 解读:这张图展示了 ThinkGen 的五阶段训练流程。前三阶段是 supervised pre-training:先只训练 connector 完成 MLLM–DiT 对齐,再做 60M 规模的预训练,最后用 0.7M 高质量子集做高分辨率微调;后两阶段是 SepGRPO:先固定 DiT 做 MLLM-GRPO,让 MLLM 学会生成更适合 DiT 的指令;再固定 MLLM 做 DiT-GRPO,让 DiT 在这些指令上继续提升生成质量。图里最重要的信号是,作者并没有尝试“一次性联合优化所有参数”,而是把 credit assignment 拆开处理。

五个阶段分别是:

  1. Stage1 Alignment:只训练 connector,引入 learnable Prepadding States,对齐 MLLM 与 DiT;
  2. Stage2 Pre-training:训练全部 DiT 参数,使用约 60M 图像样本;
  3. Stage3 High-quality fine-tuning:用 0.7M 高质量数据做 1024 分辨率微调;
  4. Stage4 MLLM-GRPO:固定 DiT,对 MLLM 做多场景 GRPO;
  5. Stage5 DiT-GRPO:固定 MLLM,对 DiT 做 FlowGRPO 风格优化。

SepGRPO 训练场景共 5 类:

ScenarioDatasetRule Model
Semantic composition5K semantic promptsGenEval
Reasoning generation10K reasoning promptsHPSv3
Text rendering3K text rendering promptsWord Accuracy
Image editing3K editing samplesSigLIP2
Reflection3K reflection samplesNED

3.6 关键公式

(1) Rectified Flow / Flow Matching loss

其中

这对应 Stage1–3 中 DiT 的监督训练目标。

(2) Group-relative advantage

对于同一输入的 个 MLLM rollout,奖励归一化为:

(3) GRPO 目标

其中 clipped surrogate 项为:

这里 是当前策略与旧策略在 token 级别的概率比。

3.7 SepGRPO 的实现细节

补充材料给出的 RL 细节如下:

  • 训练分辨率:
  • rollout 去噪步数:20 steps
  • CFG:4,仅在前 60% steps 开启
  • (MLLM rollout 数)
  • (DiT rollout 数)
  • DiT-GRPO 只对前 60% denoising steps 回传梯度

作者显式采用 Denoising Reduction 来降低 RL 数据收集成本,这一点继承自 Flow-GRPO。

3.8 基于公开代码的伪代码

说明:截至 2026-03-11,公开仓库已提供 推理代码与模型权重,但未公开 SepGRPO 的完整训练脚本。因此下面的伪代码严格对应已发布实现;对于 Stage4/5 的训练流程,本文仅保留论文公式与文字说明,不伪造代码细节。

伪代码 1:CoT prompt 构造与 hidden-state 提取

# file: ThinkGen/pipelines/pipeline_thinkgen.py
 
def build_cot_condition(instruction, input_images, think):
    prompt = apply_chat_template([
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": add_image_tokens(input_images) + instruction},
    ])
 
    if think:
        prompt = prompt + "<|im_start|>assistant\n<think>\n"
        generated = mllm.generate(
            prompt,
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=4096,
        )
        hidden = concat_last_two_layers(generated.hidden_states)
        split_idx = find_next_token_index(generated.sequences, split_seq=[151668])
        cond_hidden = hidden[:, split_idx:]
        cond_ids = generated.sequences[:, split_idx:]
        output_text = decode_new_tokens(generated.sequences)
    else:
        prompt = prompt + "<|im_start|>assistant\n<think>\n</think>" + instruction
        tokens = tokenizer(prompt)
        hidden = mllm(tokens, output_hidden_states=True).hidden_states
        hidden = concat(hidden[-2], hidden[-1], dim=-1)
        split_idx = find_next_token_index(tokens.input_ids, split_seq=[151668])
        cond_hidden = hidden[:, split_idx:]
        cond_ids = tokens.input_ids[:, split_idx:]
        output_text = ""
 
    return cond_hidden, cond_ids, output_text

伪代码 2:VGI-refine(CUT + Prepadding)

# file: ThinkGen/models/transformers/transformer_thinkgen.py
 
def vgi_refine(text_hidden_states, text_attention_mask, prepad_embed, prepad_mask):
    bs = text_hidden_states.shape[0]
    prepad = repeat(prepad_embed, bs, axis=0)
    prepad_attn = repeat(prepad_mask, bs, axis=0)
 
    refined_states = concat([prepad, text_hidden_states], dim=1)
    refined_mask = concat([prepad_attn, text_attention_mask], dim=1)
 
    return refined_states, refined_mask

伪代码 3:DiT 采样与双条件 CFG 融合

# file: ThinkGen/pipelines/pipeline_thinkgen.py
 
def generate_with_dit(prompt_embeds, neg_prompt_embeds, input_images, steps, seed):
    ref_latents = encode_reference_images_with_vae(input_images)
    latents = randn(seed=seed)
    timesteps = scheduler.set_timesteps(steps)
 
    for i, t in enumerate(timesteps):
        pred = transformer(latents, t, prompt_embeds, ref_latents)
 
        if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
            pred_ref = transformer(latents, t, neg_prompt_embeds, ref_latents)
            pred_uncond = transformer(latents, t, neg_prompt_embeds, None)
            pred = pred_uncond \
                 + image_guidance_scale * (pred_ref - pred_uncond) \
                 + text_guidance_scale * (pred - pred_ref)
        elif text_guidance_scale > 1.0:
            pred_uncond = transformer(latents, t, neg_prompt_embeds, None)
            pred = pred_uncond + text_guidance_scale * (pred - pred_uncond)
 
        latents = scheduler.step(pred, t, latents)
 
    image = vae.decode(latents)
    return image

3.9 Code-to-paper mapping

Paper ConceptSource FileKey Class / Function
用户侧统一接口ThinkGen/model.pyThinkGen_Chat, generate_image()
CoT chat templateThinkGen/pipelines/pipeline_thinkgen.py_apply_chat_template()
只截取 </think> 之后的 hidden statesThinkGen/pipelines/pipeline_thinkgen.pyfind_next_token_index(), _get_qwen3_prompt_embeds()
VGI-refine 的 Prepadding StatesThinkGen/models/transformers/transformer_thinkgen.pyprepad_embed, prepad_mask, forward()
文本/噪声/参考图联合建模ThinkGen/models/transformers/transformer_thinkgen.pyimg_patch_embed_and_refine()
参考图编码ThinkGen/pipelines/pipeline_thinkgen.pyprepare_image(), encode_vae()
采样主循环与 CFG 融合ThinkGen/pipelines/pipeline_thinkgen.pyprocessing(), predict(), __call__()
SchedulerThinkGen/schedulers/scheduling_flow_match_euler_discrete.pyFlowMatchEulerDiscreteScheduler
可选 DPM-Solver++ThinkGen/schedulers/scheduling_dpmsolver_multistep.pyDPMSolverMultistepScheduler
SepGRPO 训练脚本公开仓库未找到截至 2026-03-11 未公开

4. Experimental Setup(实验设置)

4.1 训练数据

监督训练的数据构成如下:

  • Text-to-image:约 51M 样本(补充材料);
  • Text rendering:约 3M 样本;
  • Image editing:约 5M 样本;
  • In-context generation:约 0.2M 样本;
  • Stage3 高质量子集:0.7M;
  • 论文正文还给出一个更粗粒度表述:Stage2 训练语料总计约 60M 图像样本。

SepGRPO 使用的多场景数据为:

  • 5K semantic prompts
  • 10K reasoning prompts
  • 3K text rendering prompts
  • 3K editing samples
  • 3K reflection samples

4.2 评测 benchmark

  • WISEBench:1000 prompts,评估 reasoning generation
  • RISEBench:360 pairs,评估 reasoning editing
  • GenEval:553 prompts,评估 compositional T2I
  • DPG-Bench:1065 prompts,评估长文本生成能力
  • CVTG:2000 prompts,评估 text rendering
  • ImgEdit:737 pairs,评估 image editing

4.3 Baselines

论文比较了大量 generation-only 与 unified generation-understanding 模型,包括:

  • generation-only:SDXL, SD3-Medium / SD-3.5-large, FLUX.1-dev, PixArt-, Sana-1.6B, TextCrafter, Step1X-Edit 等
  • unified / multimodal:Janus-Pro, BLIP3-o-8B, BAGEL, OmniGen2, UniWorld, VILA-U, MetaQuery-XL 等
  • closed-source 参照:GPT-4o, Gemini-2.0

4.4 指标与训练配置

评测指标

  • WISEBench / RISEBench:各 benchmark 的官方 score
  • GenEval:Counting / Position / Overall
  • DPG:Global / Entity / Attribute / Relation / Overall
  • CVTG:Word Accuracy, NED
  • ImgEdit:多编辑类型人工/模型评分综合结果

模型与训练配置

  • MLLM:Qwen3-VL-8B-Think
  • DiT:OmniGen2-DiT-4B
  • 训练目标:Rectified Flow / Flow Matching
  • Stage1 / 2 / 3 超参数:
    • LR: / /
    • Batch size:512 / 1280 / 64
    • Steps:47K / 100K / 11K
    • Resolution: / /
  • RL 阶段:,20 denoising steps,

硬件说明:论文没有明确给出 GPU 型号、卡数或总训练时长,属于实现细节中的信息缺口。


5. Experimental Results(实验结果)

5.1 主结果

(1) Reasoning generation:WISEBench

  • ThinkGen:0.55
  • ThinkGen*:0.76
  • 最强开源对比(同类表内):BAGEL* 为 0.70,STAR 为 0.66
  • closed-source GPT-4o:0.80

结论:开启 CoT 后,ThinkGen 在 WISE 上提升 +0.21,并把开源方法上限推到接近 GPT-4o 的水平。

(2) Reasoning editing:RISEBench

  • ThinkGen:3.6
  • ThinkGen*:13.0
  • 最强开源对比:BAGEL* 为 11.9
  • Gemini-2.0:13.3
  • GPT-4o:28.9

结论:ThinkGen 在 reasoning editing 上提升非常明显(3.6 13.0),已经接近 Gemini-2.0,但距离 GPT-4o 仍有明显差距,尤其 logical reasoning 子项仍只有 1.1。

(3) Text-to-image / long prompt / text rendering

ModelGenEval OverallDPG OverallCVTG Acc.CVTG NED
BAGEL0.8285.070.350.65
OmniGen20.8083.570.520.77
ThinkGen0.8885.140.800.91
ThinkGen*0.8985.870.840.94

结论:ThinkGen 在 GenEval / DPG / CVTG 上都很强,尤其 text rendering 提升最显著,CVTG Accuracy 从 0.80 提到 0.84,远高于 BAGEL 的 0.35。

(4) Image editing:ImgEdit

  • ThinkGen:4.14
  • ThinkGen*:4.21
  • GPT-4o:4.20
  • OmniGen2:3.44
  • UniWorld:3.26

结论:ThinkGen* 在 ImgEdit 上整体达到表中最高分,略高于 GPT-4o。

Figure 2 解读:这张 teaser 直观展示了 ThinkGen 的任务覆盖面。最上方是 image generation,中间是 image editing,下面两行是 reasoning generation / reasoning editing / reflection。它想传达的重点不是“单个样例很漂亮”,而是 同一个框架可以在多种生成场景里都使用 think-driven instruction,这是本文相对特定场景 CoT 方法的真正卖点。

5.2 消融实验

(1) 训练阶段消融

StageGenEvalWISECVTG
Stage1 Alignment0.780.460.28
Stage2 Pre-training0.880.550.63
Stage3 H.Q. Tuning0.880.550.75
Stage4 MLLM-GRPO0.860.540.75
Stage4 MLLM-GRPO*0.860.760.79
Stage5 DiT-GRPO*0.890.760.84

关键信息:

  • Stage2 解决的是基础生成质量;
  • Stage3 解决的是高质量与 text rendering;
  • Stage4 真正带来 reasoning boost;
  • Stage5 再把 instruction-following 与 rendering 继续抬高。

(2) Prepadding States 有效性

  • GenEval:0.64 0.78
  • WISE:0.37 0.46
  • CVTG:0.24 0.28
  • ImgEdit:3.46 3.93
  • DPG:80.90 80.86(几乎持平)

说明 Prepadding States 主要提升的是 short-prompt / 对齐敏感任务,对 long-prompt DPG 几乎没有副作用。

(3) CUT vs ALL(VGI-refine)

  • CUT:GenEval 0.78 / WISE 0.46 / CVTG 0.28 / ImgEdit 3.93
  • ALL:GenEval 0.66 / WISE 0.31 / CVTG 0.18 / ImgEdit 3.43

这说明“把整个 CoT hidden states 全送进 DiT”是有害的,作者关于“CoT 冗余需要裁剪”的判断是成立的。

(4) 训练策略对比

  • Stage3 baseline:GenEval 0.88 / WISE 0.55 / CVTG 0.75
  • 10K reasoning data 直接做 SFT:0.85 / 0.58 / 0.67
  • 10K reasoning data 做 MLLM-GRPO:0.80 / 0.74 / 0.73
  • 24K multitask data 做 MLLM-GRPO:0.86 / 0.76 / 0.79

说明:提升 reasoning generation 的关键不是“多喂 reasoning 数据”,而是 MLLM-GRPO 这种 reward-driven 学习方式。

Figure 5 解读:这张图可视化了 MLLM-GRPO 的训练过程。横轴是训练步数,图中同时展示了 reward score、CoT length 和生成图像样例。可以看到随着训练推进,reward 稳步上升、CoT 长度增加,而图像质量也逐渐改善。它说明作者的 SepGRPO 并不是仅靠最终 benchmark 提升“碰巧有效”,而是在训练动态上也表现出越来越会思考、越来越会生成的趋势。

5.3 局限性

  1. 训练代码未公开:公开仓库目前主要是 inference code + weights,SepGRPO 的完整训练实现无法复现;
  2. 论文缺少硬件信息:没有明确 GPU 型号、卡数与训练时长;
  3. reasoning editing 仍未逼近 GPT-4o:尤其 logical reasoning 子项很弱;
  4. 代码/论文存在轻微不一致:例如 Prepadding States 的 值,论文补充材料写 25,公开代码看起来是 23。

5.4 总结

ThinkGen 最有价值的地方,不只是把 CoT 搬进视觉生成,而是提出了一个可泛化的“先想再画”范式

  • MLLM 负责 reasoning 与 instruction rewrite;
  • VGI-refine 负责把冗余 CoT 压缩成对 DiT 有用的条件;
  • SepGRPO 负责分别优化“想得对”和“画得好”。

从结果看,ThinkGen 在 reasoning generation、text rendering、image editing 上都很强;从方法看,它为后续把 RL、MLLM、DiT 组合到统一视觉生成框架里提供了一个非常清晰的模板。