Pytorch 如何定位 NaN?

orange and black bug on green leaf
A bug @ringane

有关内容在网上其实有一些比较系统的文章, 但灌水很多相对不太好找. 这里结合我自己的经验再总结一下.

注:本文讨论的是单精度训练。半精度训练会有更多的NaN成因,但不在本文讨论之列。

检查方法

detect_anomaly

常见用法:

# loss = model(X)
with torch.autograd.detect_anomaly():
    loss.backward()

文档:detect_anomaly

在计算时出现NaN时即时报错. 其给出的错误位置可能不甚准确(尤其是反向传播才出现NaN时), 但其错误信息往往会对你的调试工作有所启发。调试反向传播出现的NaN时,思考的重点应该放在调试信息中的backward_fn上。

assert

我们已经知道,NaN是一种具有传染性的错误。NaN调试的关键问题是定位第一个NaN出现的位置。

在 pytorch 中, 检查NaN的函数为torch.isnan(T). 于是我们可以构造如下断言:

assert not torch.any(torch.isnan(T))

这是一个需要判断力的方法:将这个断言加在你认为有可能出现NaN的步骤之后. 你的判断可以帮助你少加断言、快速找到第一现场。

不过没找到第一现场也不用担心,尽管这个现场已经漂移, 你也可以利用调试器的堆栈功能快速搜寻真正的事发现场.

NaN的可能原因

讲完三板斧总得讲讲NaN的成因, 要不然就是光有方法没有理论(x 尤其是 assert, 要求调试者非常充分且熟练地掌握NaN的可能成因.

梯度爆炸

梯度爆炸, 或者梯度消失都可能导致NaN. 这个问题往往会被detect_anomaly捕获, 但真正定位到问题却难上加难. 相对来说, 重新推导一遍自己的理论模型、寻找可能导致梯度爆炸的计算显得更有针对性.

计算不合法

这也是NaN最常见的成因. 毕竟大多数的网络, 尤其是复现、组合别人的网络结构一般不会碰到梯度爆炸的问题, 而NaN大多出现于 loss 计算的部分, 诞生于某个小小的不合法计算, 然后污染它参与计算的所有结果, 最后在你的 loss 值上表现出来.

常见套路:

  • $ \log x, x \leq 0 $
  • $ c/0 $
  • $ \sqrt 0 $, 正向传播正常,反向传播无法计算梯度

$abs(0)$似乎在反向传播时不会产生NaN

尚有其他的一些情况我自己没遇到过, 网上可能会有补充

这种问题运气好的话会被detect_anomaly直接找到, 但通常是找到一个漂移了亿点点的位置. 推荐用assert的办法, 尤其是 自己写了loss时, 在关键位置放几个assert守门, 总归是没错的.

注意, 绝大多数时候, inf也是不合常理的存在. 因此你可能也需要同时寻找inf:

assert not torch.any(torch.isnan(T) + torch.isinf(T))

脏数据

NaN的次常见成因. 顾名思义, 出现NaN仅仅是因为送进模型的数据里含有NaN. 通常来说直接读图片不会出现NaN, 往往是大意地处理数据后会出现这种情况.

随便举个例子.

mask = mask / mask.max()
# serialize mask

这句话看起来没问题, 把 mask 的值域缩放到float32[0, 1]. 在遇到一张纯黑(max=0)的 mask 之前,很难注意到这里存在这样的隐患 :(

毕竟我们总是假设,标注的 mask 至少是有一个非零值的。但不管怎么说, 此时我们犯了”除零”的错误. 这个 mask 会变成携带NaN的脏数据输入模型, 并在计算 loss 时将 loss 结果污染. 如果程序没有及时终止, 在仅仅一次反向传播之后, 你的模型参数将变为NaN, 其一切推导也将得出NaN

检查NaN的一般步骤

  1. (静态地)检查非法计算 & 检查数据
  2. 开启异常检测
  3. 给模型的直接输出结果和最终 loss 加assert
  4. 通过经验、猜测、反推等方法逐步把assert加到之前的步骤, 直到触发的assert帮你找到了不合法计算
  5. 若计算 loss 的过程中没有发现问题, 且总是触发反向传播异常, 那可以考虑从理论上检查梯度爆炸和梯度消失