PyTorch 如何定位 NaN?
通过 detect_anomaly、assert 等方法定位深度学习中的 NaN 异常,以及一些常见的问题原因
有关内容在网上其实有一些比较系统的文章,但质量参差不齐,优质内容不易定位。这里结合我自己的经验再总结一下。
注:本文讨论的是单精度训练。半精度训练会有更多的
NaN成因,但不在本文讨论之列。
检查方法
detect_anomaly
常见用法:
# loss = model(X)
with torch.autograd.detect_anomaly():
loss.backward()
在计算时出现 NaN 时即时报错。其给出的错误位置可能不甚准确(尤其是反向传播才出现 NaN 时),但其错误信息往往会对你的调试工作有所启发。调试反向传播出现的 NaN 时,思考的重点应该放在调试信息中的 backward_fn 上。
assert
我们已经知道,NaN 是一种具有传染性的错误。NaN 调试的关键问题是定位第一个 NaN 出现的位置。
在 PyTorch 中,检查 NaN 的函数为 torch.isnan(T)。于是我们可以构造如下断言:
assert not torch.any(torch.isnan(T))
这是一个需要判断力的方法:将这个断言加在你认为有可能出现 NaN 的步骤之后。你的判断可以帮助你少加断言、快速找到第一现场。
不过没找到第一现场也不用担心,尽管这个现场已经漂移,你也可以利用调试器的堆栈功能快速搜寻真正的事发现场。
NaN 的可能原因
讲完三板斧总得讲讲 NaN 的成因,要不然就是光有方法没有理论。尤其是 assert,要求调试者非常充分且熟练地掌握 NaN 的可能成因。
梯度爆炸
梯度爆炸,或者梯度消失都可能导致 NaN。这个问题往往会被 detect_anomaly 捕获,但真正定位到问题却难上加难。相对来说,重新推导一遍自己的理论模型、寻找可能导致梯度爆炸的计算显得更有针对性。
计算不合法
这也是 NaN 最常见的成因。毕竟大多数的网络,尤其是复现、组合别人的网络结构一般不会碰到梯度爆炸的问题,而 NaN 大多出现于 loss 计算的部分,诞生于某个小小的不合法计算,然后污染它参与计算的所有结果,最后在你的 loss 值上表现出来。
常见套路:
- $ \log x, x \leq 0 $
- $ c/0 $
- $ \sqrt 0 $,正向传播正常,反向传播无法计算梯度
$abs(0)$ 似乎在反向传播时不会产生
NaN
尚有其他的一些情况我自己没遇到过,网上可能会有补充。
这种问题运气好的话会被 detect_anomaly 直接找到,但通常是找到一个漂移了不少距离的位置。推荐用 assert 的办法,尤其是 自己写了 loss 时,在关键位置放几个 assert 守门,总归是没错的。
注意,绝大多数时候,inf 也是不合常理的存在。因此你可能也需要同时寻找 inf:
assert not torch.any(torch.isnan(T) + torch.isinf(T))
脏数据
NaN 的次常见成因。顾名思义,出现 NaN 仅仅是因为送进模型的数据里含有 NaN。通常来说直接读图片不会出现 NaN,往往是大意地处理数据后会出现这种情况。
随便举个例子。
mask = mask / mask.max()
# serialize mask
这句话看起来没问题,把 mask 的值域缩放到 float32[0, 1]。在遇到一张纯黑(max=0)的 mask 之前,很难注意到这里存在这样的隐患。
毕竟我们总是假设,标注的 mask 至少是有一个非零值的。但不管怎么说,此时我们犯了"除零"的错误。这个 mask 会变成携带 NaN 的脏数据输入模型,并在计算 loss 时将 loss 结果污染。如果程序没有及时终止,在仅仅一次反向传播之后,你的模型参数将变为 NaN,其一切推导也将得出 NaN。
检查 NaN 的一般步骤
- (静态地)检查非法计算 & 检查数据
- 开启异常检测
- 给模型的直接输出结果和最终 loss 加
assert - 通过经验、猜测、反推等方法逐步把
assert加到之前的步骤,直到触发的assert帮你找到了不合法计算 - 若计算 loss 的过程中没有发现问题,且总是触发反向传播异常,那可以考虑从理论上检查梯度爆炸和梯度消失
Hook this up to your favourite commenting platform — Giscus, Disqus, or your own.