Pytorch 如何定位 NaN?
有关内容在网上其实有一些比较系统的文章, 但灌水很多相对不太好找. 这里结合我自己的经验再总结一下.
注:本文讨论的是单精度训练。半精度训练会有更多的
NaN
成因,但不在本文讨论之列。
检查方法
detect_anomaly
常见用法:
# loss = model(X)
with torch.autograd.detect_anomaly():
loss.backward()
在计算时出现NaN
时即时报错. 其给出的错误位置可能不甚准确(尤其是反向传播才出现NaN时), 但其错误信息往往会对你的调试工作有所启发。调试反向传播出现的NaN
时,思考的重点应该放在调试信息中的backward_fn
上。
assert
我们已经知道,NaN
是一种具有传染性的错误。NaN
调试的关键问题是定位第一个NaN
出现的位置。
在 pytorch 中, 检查NaN
的函数为torch.isnan(T)
. 于是我们可以构造如下断言:
assert not torch.any(torch.isnan(T))
这是一个需要判断力的方法:将这个断言加在你认为有可能出现NaN
的步骤之后. 你的判断可以帮助你少加断言、快速找到第一现场。
不过没找到第一现场也不用担心,尽管这个现场已经漂移, 你也可以利用调试器的堆栈功能快速搜寻真正的事发现场.
NaN的可能原因
讲完三板斧总得讲讲NaN
的成因, 要不然就是光有方法没有理论(x 尤其是 assert, 要求调试者非常充分且熟练地掌握NaN
的可能成因.
梯度爆炸
梯度爆炸, 或者梯度消失都可能导致NaN
. 这个问题往往会被detect_anomaly捕获, 但真正定位到问题却难上加难. 相对来说, 重新推导一遍自己的理论模型、寻找可能导致梯度爆炸的计算显得更有针对性.
计算不合法
这也是NaN
最常见的成因. 毕竟大多数的网络, 尤其是复现、组合别人的网络结构一般不会碰到梯度爆炸的问题, 而NaN
大多出现于 loss 计算的部分, 诞生于某个小小的不合法计算, 然后污染它参与计算的所有结果, 最后在你的 loss 值上表现出来.
常见套路:
- $ \log x, x \leq 0 $
- $ c/0 $
- $ \sqrt 0 $, 正向传播正常,反向传播无法计算梯度
$abs(0)$似乎在反向传播时不会产生
NaN
尚有其他的一些情况我自己没遇到过, 网上可能会有补充
这种问题运气好的话会被detect_anomaly直接找到, 但通常是找到一个漂移了亿点点的位置. 推荐用assert的办法, 尤其是 自己写了loss时, 在关键位置放几个assert
守门, 总归是没错的.
注意, 绝大多数时候, inf
也是不合常理的存在. 因此你可能也需要同时寻找inf
:
assert not torch.any(torch.isnan(T) + torch.isinf(T))
脏数据
NaN
的次常见成因. 顾名思义, 出现NaN
仅仅是因为送进模型的数据里含有NaN
. 通常来说直接读图片不会出现NaN
, 往往是大意地处理数据后会出现这种情况.
随便举个例子.
mask = mask / mask.max()
# serialize mask
这句话看起来没问题, 把 mask 的值域缩放到float32[0, 1]
. 在遇到一张纯黑(max=0)的 mask 之前,很难注意到这里存在这样的隐患 :(
毕竟我们总是假设,标注的 mask 至少是有一个非零值的。但不管怎么说, 此时我们犯了”除零”的错误. 这个 mask 会变成携带NaN
的脏数据输入模型, 并在计算 loss 时将 loss 结果污染. 如果程序没有及时终止, 在仅仅一次反向传播之后, 你的模型参数将变为NaN
, 其一切推导也将得出NaN
。
检查NaN的一般步骤
- (静态地)检查非法计算 & 检查数据
- 开启异常检测
- 给模型的直接输出结果和最终 loss 加
assert
- 通过经验、猜测、反推等方法逐步把
assert
加到之前的步骤, 直到触发的assert
帮你找到了不合法计算 - 若计算 loss 的过程中没有发现问题, 且总是触发反向传播异常, 那可以考虑从理论上检查梯度爆炸和梯度消失