概念界定
梯度裁剪是在梯度过大时对其进行缩放或截断的稳定训练技术。它常用于防止梯度爆炸和训练初期 loss spike。
背景与问题
深层网络或长序列训练中,某些 batch 可能产生异常大的梯度。如果直接用这些梯度更新参数,模型可能出现数值不稳定甚至发散。梯度裁剪用于限制单步更新的最大强度。
定义与记号
常见方式是按全局范数裁剪:
if ||g|| > c:
g = c · g / ||g||其中:
g是所有参数梯度组成的整体梯度。c是裁剪阈值。||g||通常是 L2 norm。
直观解释
梯度裁剪像给参数更新加了安全阀。正常梯度不受影响,异常大的梯度会被缩小到可接受范围。
基本性质
- 梯度裁剪不改变 loss 定义,只改变更新过程。
- 裁剪阈值过小会抑制正常学习。
- 裁剪阈值过大则无法有效防止异常更新。
- 大模型训练中常配合 mixed precision、AdamW 和 warmup 使用。
示例
训练日志中如果出现突然的 loss spike 或 grad norm 爆炸,可以检查是否需要梯度裁剪、调整学习率或排查数据异常。
常见误解
- 误解:梯度裁剪能修复所有训练不稳定。
- 正确理解:它只能限制梯度幅度,无法解决数据错误、学习率过大或数值溢出等根因。
- 误解:梯度裁剪越强越安全。
- 正确理解:过强裁剪会让有效学习信号变弱。