entmax_bisect bugs with fp16/bf16 #30

@bpopeters

Description

It would be great if entmax worked with torch.float16 and torch.bfloat16. Unfortunately, it currently does not: there are bugs in both the bisection-based and the exact algorithms. Here I'll document a numerical stability problem that affects the bisection-based algorithm for both torch.float16 and torch.bfloat16 (don't believe the propaganda that says bf16 is a drop-in replacement for float32).
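
For context, a quick sketch of the relevant dtype limits, as reported by torch.finfo: bf16 keeps float32's exponent range, which is where the "drop-in" claim comes from, but its precision (eps) is far coarser than even fp16's.

import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # bf16 matches float32's range (same 8 exponent bits) but keeps only
    # ~8 bits of significand, so its eps is much larger than float16's.
    print(dtype, "max:", info.max, "eps:", info.eps)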

Let's say you have a 32-bit vector of logits whose largest element is sufficiently negative.

import torch
import entmax

a = torch.zeros(128, device="cuda").fill_(-5)  # torch.float32
a[0] = 0
a -= 1000

With alpha=1.5, the correct output for this vector is a one-hot distribution peaked on index 0. We get this behavior with both entmax.entmax15 and entmax.entmax_bisect.

p1 = entmax.entmax15(a)
p2 = entmax.entmax_bisect(a, alpha=1.5)

p1[0] == p2[0] == 1  # True

Ok, great. But what happens if we use torch.float16?

b = a.to(torch.float16)

p3 = entmax.entmax_bisect(b, alpha=1.5)
p3.isnan().all()  # True

and what about torch.bfloat16?

c = a.to(torch.bfloat16)

p4 = entmax.entmax_bisect(c, alpha=1.5)
p4.isnan().all()  # True

Well that's not good! (solution after this commercial break)
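
In the meantime, one workaround (a minimal sketch of the usual upcasting trick, not necessarily the fix that ends up in the library) is to run the bisection in float32 and cast the result back. Since the entmax output is invariant to adding a constant to all logits, subtracting the row-wise max first also keeps the intermediate values in a comfortable range.

import torch
import entmax

def entmax_bisect_stable(x, alpha=1.5, dim=-1):
    # Hypothetical wrapper, not part of the entmax package: shift by the
    # row max (the output is shift-invariant), bisect in float32, then
    # cast back to the input dtype.
    x32 = x.float()
    x32 = x32 - x32.max(dim=dim, keepdim=True).values
    p = entmax.entmax_bisect(x32, alpha=alpha, dim=dim)
    return p.to(x.dtype)

p5 = entmax_bisect_stable(b, alpha=1.5)
p5[0] == 1  # expected: True

This only sidesteps the half-precision path, of course; the real fix is to make entmax_bisect itself numerically stable in fp16/bf16.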
