LogitsMetadata、裁剪状态与选择性算 logits#

执行模型前面的章节已经把 LogitsProcessorOutputSamplerReq.check_finished() 和 hidden states / routed experts 的旁路返回逐层讲开了,但如果没有一章把“哪些位置会继续进入 logits、logprobs 和 hidden states 处理链”单独讲清,execution model 仍然会少一个关键现实:系统并不会对 batch 里每一个位置都一视同仁地继续算下去。

真正决定“哪些位置值得继续处理”的,是:

  • LogitsMetadata
  • _get_pruned_states()
  • _get_hidden_states_to_store()

这章的价值就在于把这条选择性处理链讲透。只要这层读稳了,前面那些看起来像“系统为什么有时只看最后一个 token、有时又要保留更多位置”的现象,就会重新变成一条可理解的执行优化主线。

先把这层放回 execution model 的真正问题#

很多读者会自然假设:forward 出来多少 hidden states,后半段就继续处理多少 hidden states。源码并不是这样工作的。execution model 到这里首先要解决的不是“怎样算更多”,而是“哪些位置值得继续算”。下面这张图的作用,就是把这条工作集选择链画出来:

flowchart LR
    Hidden["forward hidden_states"] --> Meta["LogitsMetadata"]
    Meta --> Prune["_get_pruned_states()"]
    Prune --> Logits["positions for logits/logprobs"]
    Prune --> Sample["sample_indices"]
    Prune --> InputLP["input_logprob_indices"]
    Prune --> Store["_get_hidden_states_to_store()"]
    Store --> Return["optional returned hidden states"]

图里最重要的一点是:执行后半段首先解决的是“哪些位置值得继续处理”,而不是立刻对整段状态做统一操作。

LogitsMetadata 不是参数袋,而是后半段的执行计划#

如果只看字段,LogitsMetadata 很容易被读成普通配置对象。更稳的理解是,它其实在回答执行后半段最关键的几个问题:

  • 当前 batch 处在什么 forward_mode
  • 是否要返回 hidden states,返回哪种 hidden-state 形态
  • 是否需要 input logprobs
  • 对哪些 extend token 真正需要继续算 logits / logprobs
  • 本轮是不是 prefill-only 人格

这意味着它更像“logits 阶段的执行计划”,而不是单纯承接 ForwardBatch 字段的容器。

from_forward_batch(...) 值得最先读#

如果只挑一个入口,最推荐的就是 LogitsMetadata.from_forward_batch(...)。因为这一步真正做了“把 ForwardBatch 的执行人格编译成后续 logits 处理计划”的工作。尤其重要的是,它会结合:

  • forward_mode
  • return_logprob
  • extend_seq_lens_cpu
  • extend_logprob_start_lens_cpu

推导出:

  • extend_return_logprob
  • extend_logprob_pruned_lens_cpu

这说明“要不要算 logprobs”根本不是一个单独布尔值,而是要和本轮到底是 extend 还是 decode、各序列从哪里开始需要 logprobs 一起看。也正因此,LogitsMetadata 更像一次局部执行编译,而不是简单字段拷贝。

_get_pruned_states() 是这条链真正的核心#

如果要只深入一个函数,最值得看的其实就是它。因为它直接回答:

  • 哪些 hidden states 会继续走 logits 计算
  • 哪些位置会被拿来 sample
  • 哪些位置会被拿来算 input logprobs

从结构上看,它至少处理三类情况。

第一类是 decode / idle / target verify / draft extend v2。
这类路径最直接,系统基本不会进一步裁到“每序列最后一个 token”,因为当前输入本来就已经是 decode 或特殊验证语义。

第二类是 extend 且不返回 input logprobs。
这时系统只取每个序列最后一个有效位置,也就是“只保留真正参与 next-token 预测的那些点”。这正解释了为什么很多普通生成路径里,后半段看起来像只关心最后一个 token。

第三类是 extend 且要返回 input logprobs。
这时就复杂很多了。系统不仅要保留 pruned_states,还要同时维护:

  • sample_indices
  • input_logprob_indices
  • token_to_seq_idx

也就是说,一旦打开 input logprob 返回,execution model 后半段就不再只是“拿最后一个 token”,而是要保留一段更复杂的有效窗口。

sample_indicesinput_logprob_indices 回答的是不同问题#

这两个名字特别容易被混成“反正都是索引”。更稳的理解是:

  • sample_indices 决定哪些位置会产出真正用于 next-token 选择的 sampled logits
  • input_logprob_indices 决定哪些位置需要计算 input logprobs

它们虽然都来自同一段 pruned_states,但服务的是完全不同的后续语义。技术书在这里最该做的事情,就是把这种“名字相近但问题不同”的变量明确拆开,否则读者在 logprob 和 hidden-state 返回逻辑里很容易迷路。

pruned_states 本身就是一次非常现实的工程优化#

从实现上看,prune 的意义很直接:

  • 不再对整段无用 hidden states 做完整 logits 计算
  • 只保留后续真正需要的那些位置

这同时带来两类收益:

  • 节省计算
  • 降低后续 logprobs / hidden states 返回的内存和拷贝开销

因此这条链真正讲的不是“优雅裁剪”,而是 execution model 怎么为性能和返回语义之间做工作集折中。

_get_hidden_states_to_store() 决定了最终到底带出多少 hidden states#

如果 _get_pruned_states() 决定“哪些位置继续算”,那么 _get_hidden_states_to_store() 决定的就是:

  • 最后真的要把哪些 hidden states 带出 execution layer

这一步主要受 capture_hidden_mode 控制。

capture_hidden_mode.is_full() 时,系统会尽量保留更完整的 hidden states 视图,甚至在存在 aux_hidden_states 时先拼接。
capture_hidden_mode.is_last() 时,系统只保留最后位置的 hidden states;如果 sample_indices 已经存在,又会进一步只取 sample 位置。

这意味着即使 return_hidden_states 打开了,最终带出去的也不一定是整段 full states,而可能只是最后位置或 sample 位置。也就是说,“要不要返回 hidden states”和“返回多少 hidden states”是两层不同的判断。

hidden_states_before_norm 的优先级特别值得注意#

源码里有个非常成熟的设计信号:

  • 如果 hidden_states_to_store_before_norm 不为空
  • 最终优先返回它

这说明 hidden states 返回语义不是“谁更方便就返谁”,而是在显式偏好更接近某种稳定边界的状态口径。这类细节特别适合技术书点出来,因为它能帮助读者理解:返回的隐藏状态本身也是被认真选择过的语义,而不是执行层顺手留出来的任何一份 tensor。

is_prefill_only 让 logits 处理不再只服务 next-token 生成#

LogitsMetadata 里的 is_prefill_only 值得单独提,是因为它提醒读者:这条 logits 处理链并不永远只服务于 next-token generation。在某些人格里,它更像:

  • multi-item scoring
  • embedding / pooling 前的评分视图
  • 非生成型 logits 消费路径

这让 5.15 embedding、pooling 与 rerank 执行路径 和本章之间形成了很自然的回扣:logits 层本身也会随着请求人格切换而改写工作集和返回策略。

这章和 5.8、5.16 的边界#

这三章各自回答的是不同问题:

也就是说,5.8 讲“处理之后怎样落回 request 和 surface”,5.16 讲“会额外带出什么”,而 5.18 讲“从一开始就选了哪些位置继续处理”。把这三层分开,execution model 的后半段才不会混成一团。

最容易出现的三种误判#

第一,误以为系统总是对整段 hidden states 都算 logits。
实际上 _get_pruned_states() 会先做非常激进的工作集裁剪。

第二,误以为 sample_indicesinput_logprob_indices 是同一组位置。
它们分别服务 sampled logits 与 input logprobs,语义完全不同。

第三,误以为打开 hidden states 返回就一定返 full states。
真正返回多少,还要看 capture_hidden_mode 和 before-norm 优先逻辑。

真正顺着源码读这条“选择性处理”链时,推荐顺序#

建议按这个顺序:

  1. LogitsMetadata.from_forward_batch(...)
  2. _get_pruned_states()
  3. _get_hidden_states_to_store()
  4. 最后再回到 5.8LogitsProcessorOutput 与 output processor

这样读,你会先理解“选哪些位置”,再理解“这些位置的结果怎样继续被使用和带出”,最不容易在 logprob / hidden states 逻辑里迷路。

小结#

这一章真正补齐的,是 execution model 里一条此前很关键、却仍然偏隐性的主线:系统不会对所有位置一视同仁地处理 logits 和 hidden states,LogitsMetadata 和两步裁剪函数共同定义了后半段真正的工作集。

读懂这层之后,execution model 对“为什么某些位置被继续算、某些位置被跳过”也就有了正式解释。