概念界定
Logits 是模型输出的未归一化分数,输出头是把 hidden state 映射到任务输出空间的线性层。在语言模型中,LM Head 通常把最后一层 hidden state 映射到词表大小的 logits。
背景与问题
Transformer 内部 hidden state 的维度通常是 D,但语言模型最终需要在 V 个词表 token 中预测下一个 token。因此需要一个输出头把 [D] 维表示映射到 [V] 维 logits。
结构与机制
最后一层 hidden state:
h_t: [D]LM Head 权重:
W_vocab: [D, V]输出 logits:
z_t = h_t W_vocab
z_t: [V]再经过 softmax:
p_t = softmax(z_t)得到下一个 token 的概率分布。
直观解释
输出头可以理解为对每个候选 token 打分。logit 越高,经过 softmax 后该 token 的概率通常越高。
基本性质
- logits 不是概率,可以为负,也不要求总和为 1。
- softmax 把 logits 转换成概率分布。
- LM Head 的参数量通常与词表大小和 hidden size 成正比。
- 有些模型会把 input embedding 和 output head 权重绑定。
示例
语言模型预测下一个 token:
final hidden state: [B, T, D]
lm_head: [D, V]
logits: [B, T, V]训练时,对每个位置的 logits 和真实下一个 token 计算交叉熵。
常见误解
- 误解:logits 就是概率。
- 正确理解:logits 是未归一化分数,需要 softmax 转成概率。
- 误解:输出头只在分类模型中存在。
- 正确理解:语言模型也有输出头,只是输出空间是词表。
- 误解:logit 最大的 token 一定总会被输出。
- 正确理解:贪心解码会选最大 logit,但采样解码可能选择其他 token。