Description
When all inputs to entmax are `-inf`, it fails with the following traceback:
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-404-217bd9c1ced2> in <module>
      1 from entmax import entmax15
      2 logits = torch.ones(10) * float('-inf')
----> 3 entmax15(logits)

~/.virtualenvs/sparseref/lib/python3.7/site-packages/entmax/activations.py in entmax15(X, dim, k)
    254     """
    255
--> 256     return Entmax15Function.apply(X, dim, k)
    257
    258

~/.virtualenvs/sparseref/lib/python3.7/site-packages/entmax/activations.py in forward(cls, ctx, X, dim, k)
    176         X = X / 2  # divide by 2 to solve actual Entmax
    177
--> 178         tau_star, _ = _entmax_threshold_and_support(X, dim=dim, k=k)
    179
    180         Y = torch.clamp(X - tau_star, min=0) ** 2

~/.virtualenvs/sparseref/lib/python3.7/site-packages/entmax/activations.py in _entmax_threshold_and_support(X, dim, k)
    129
    130     support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
--> 131     tau_star = tau.gather(dim, support_size - 1)
    132
    133     if k is not None and k < X.shape[dim]:

RuntimeError: index -1 is out of bounds for dimension 0 with size 10
```
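A plausible reading of the traceback, sketched below: `support_size` ends up as 0, so `support_size - 1` is the index -1 that `gather` rejects. This assumes `entmax15` subtracts the maximum before the `X / 2` step shown above (the usual softmax stability trick; that line is not visible in the traceback). With all-`-inf` input, `-inf - (-inf)` is NaN. Every comparison against NaN is false, so no entry is counted in the support. The snippet only reproduces that arithmetic with a simplified threshold (a running mean of the sorted values); it is not the library's implementation.

```python
import torch

# All-(-inf) logits, as in the report.
X = torch.full((10,), float("-inf"))

# Assumed max-shift for numerical stability: -inf - (-inf) is NaN.
max_val, _ = X.max(dim=0, keepdim=True)
X = (X - max_val) / 2

# Simplified stand-in for the threshold candidates: a running mean of the
# sorted values. With NaN input every candidate is NaN as well.
Xsrt, _ = torch.sort(X, dim=0, descending=True)
rho = torch.arange(1, X.shape[0] + 1, dtype=X.dtype)
tau = Xsrt.cumsum(0) / rho

# Every comparison with NaN is False, so the support is empty.
support_size = (tau <= Xsrt).sum(0).unsqueeze(0)
print(support_size)  # tensor([0])

# gather is then asked for index support_size - 1 == -1, which it rejects.
gather_failed = False
try:
    tau.gather(0, support_size - 1)
except RuntimeError as err:
    gather_failed = True
    print("gather failed:", err)
```

Even without the max-shift, the threshold computation in the library squares and accumulates `-inf` values, which yields `inf - inf`. That is also NaN, so the support would still come out empty.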
A minimal snippet to reproduce this behavior:

```python
import torch
from entmax import entmax15

logits = torch.ones(10) * float('-inf')
entmax15(logits)  # raises RuntimeError: index -1 is out of bounds ...
```
For reference, `torch.softmax(logits, dim=0)` returns a tensor of NaNs for the same input. This is admittedly a corner case, but padding can produce inputs that are entirely `-inf`, and NaNs are easier to deal with later than an exception.
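Until the library handles this case, one possible workaround is a thin wrapper that mirrors the `torch.softmax` behavior: slices that are entirely `-inf` get NaN outputs, and every other slice is passed to the activation unchanged. The helper name `masked_activation` is hypothetical (it is not part of entmax). The sketch assumes the activation accepts `(X, dim=...)`, as `entmax15` and `torch.softmax` do.

```python
import torch


def masked_activation(fn, X, dim=-1):
    """Apply fn along dim, returning NaN for slices that are all -inf.

    Hypothetical helper, not part of entmax. `fn` is any callable taking
    (X, dim=...), e.g. entmax.entmax15 or torch.softmax.
    """
    # True for every slice along `dim` that contains only -inf.
    all_neg_inf = torch.isneginf(X).all(dim=dim, keepdim=True)
    # Replace those slices with zeros so fn only ever sees finite rows there.
    safe_X = X.masked_fill(all_neg_inf, 0.0)
    out = fn(safe_X, dim=dim)
    # Restore NaN where the input carried no information, as torch.softmax does.
    return out.masked_fill(all_neg_inf, float("nan"))


# Usage: the second row is fully padded.
logits = torch.tensor([[0.0, 1.0, float("-inf")],
                       [float("-inf")] * 3])
probs = masked_activation(torch.softmax, logits, dim=-1)
# With entmax installed: masked_activation(entmax15, logits, dim=-1)
```

Masking after the call keeps the NaNs explicit, so downstream code can still detect fully padded rows with `torch.isnan`. Returning zeros instead would be a one-line change if that suits the padding scheme better.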
This is possibly related to #9.