如何手写softmax函数防止数值溢出?
当我手写cross-entropy的时候,发现有时候竟然会出现error?整个数学计算过程没问题,主要问题就在于上溢出和下溢出,即当遇到极大或是极小的logits的时候,如果直接用公式按照exp的方式去进行softmax的话就会出现数值溢出的情况。为了解决这个问题,首先需要做的就是减去最大值,即:
logits = logits - torch.max(logits, 1)[0][:, None]
原理可以看这个链接:
https://zhuanlan.zhihu.com/p/29376573
但是我减去最大值之后还是会出现溢出,这个时候经过检查发现softmax后还是出现了0的情况,那再经过log函数之后就会变成负无穷,此时不要用手写的:
torch.log(F.softmax(logits, dim=-1))
而是直接使用torch自带的log_softmax,其做了一定的容错控制:
F.log_softmax(logits, dim=-1)
或者在使用log的时候加一个很小的数,防止出现0的情况。
版权声明:本文为weixin_42988382原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。