It would be great if entmax worked with torch.float16 and torch.bfloat16. Unfortunately, it currently does not: there are bugs in both the bisection-based and the exact algorithms. Here I'll document a numerical stability problem in the bisection-based algorithm that affects both torch.float16 and torch.bfloat16 (don't believe the propaganda that says bf16 is a drop-in replacement for float32).
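For context (this just prints torch.finfo, nothing entmax-specific): float16 trades exponent range for precision, while bfloat16 keeps float32's exponent range but has far fewer mantissa bits, so each can fail in its own way.
import torch

# float16: max ~65504, eps ~9.8e-4; bfloat16: max ~3.4e38, eps ~7.8e-3
for dt in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dt)
    print(dt, "max:", info.max, "eps:", info.eps)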
Let's say you have a 32-bit vector of logits whose largest element is sufficiently negative.
import torch
import entmax

a = torch.zeros(128, device="cuda").fill_(-5)  # torch.float32
a[0] = 0
a -= 1000  # shift everything far into the negatives; the gaps between logits are unchanged
With alpha=1.5, the correct output for this vector is a one-hot distribution peaked on index 0: the gap of 5 between the largest logit and all the others is more than enough to push every other entry out of the support. We get exactly this behavior with both entmax.entmax15 and entmax.entmax_bisect.
p1 = entmax.entmax15(a)
p2 = entmax.entmax_bisect(a, alpha=1.5)
p1[0] == p2[0] == 1 # True
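The -1000 shift is harmless here because entmax, like softmax, is invariant to adding a constant to every logit. A quick float32 sanity check, reusing a and p1 from above (p_shifted is just a name I'm introducing):
p_shifted = entmax.entmax15(a + 1000)
torch.allclose(p1, p_shifted)  # True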
Ok, great. But what happens if we use torch.float16?
b = a.to(torch.float16)
p3 = entmax.entmax_bisect(b, alpha=1.5)
p3.isnan().all() # True
And what about torch.bfloat16?
c = a.to(torch.bfloat16)
p4 = entmax.entmax_bisect(c, alpha=1.5)
p4.isnan().all() # True
Well, that's not good! (Solution after this commercial break.)
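In the meantime, here is a workaround sketch (my illustration only, not necessarily the fix teased above; the helper name entmax_bisect_stable is made up): exploit the shift invariance, run the bisection in float32, and cast back afterwards. Applied to the bfloat16 tensor c from above, it should recover the one-hot output.
def entmax_bisect_stable(x, alpha=1.5, dim=-1):
    # Shift so the largest logit is 0 (entmax is shift-invariant),
    # run the bisection in float32, then cast back to the input dtype.
    shifted = (x - x.max(dim=dim, keepdim=True).values).float()
    return entmax.entmax_bisect(shifted, alpha=alpha, dim=dim).to(x.dtype)

p5 = entmax_bisect_stable(c, alpha=1.5)
p5[0] == 1  # True, matching the float32 result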