Hi, I was plotting the output of the entmax_bisect function at different levels of alpha. Here is the code for alpha = 10:
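import torch
from entmax import entmax_bisect
from matplotlib import pyplot as plt
x = torch.linspace(-1, 1, 10000).unsqueeze(-1)  # 10000 values of x in [-1, 1]
x = torch.cat((x, torch.zeros_like(x)), dim=-1)  # logits (x, 0)
y = entmax_bisect(x, 10, dim=-1)[..., 0]  # first-class probability, alpha = 10
# Plot against the first column only; the second column of x is all zeros
plt.plot(x[..., 0].numpy(), y.numpy())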
Here's what I'm getting

I'm expecting this to be a smooth function, except at the two points where it transitions to 0 or 1.
Is there any explanation for this?
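
For reference, here is a rough sketch of the curve I'd expect in this two-class case. It assumes the usual alpha-entmax form p_i = [(alpha - 1) z_i - tau]_+^(1/(alpha - 1)) with the probabilities summing to 1. For logits (x, 0) with both classes in the support, subtracting the two equations eliminates tau and leaves p^(alpha - 1) - (1 - p)^(alpha - 1) = (alpha - 1) x, which can be solved for p directly. The helper name `entmax_two_class` is my own and is not part of the entmax package; it uses plain Python floats, so it only serves as a cross-check, not as a replacement for entmax_bisect.

```python
def entmax_two_class(x, alpha=10.0, n_iter=60):
    """First alpha-entmax probability for logits (x, 0), assuming alpha > 1.

    Hypothetical reference helper, not part of the entmax package.
    """
    k = alpha - 1.0
    target = k * x
    # Outside |x| < 1 / (alpha - 1), only one class is in the support.
    if target <= -1.0:
        return 0.0
    if target >= 1.0:
        return 1.0
    lo, hi = 0.0, 1.0
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        # g(p) = p**k - (1 - p)**k increases monotonically from -1 to 1 on [0, 1],
        # so bisection on p finds the unique root of g(p) = target.
        if mid ** k - (1.0 - mid) ** k < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Same range of x as in the plot above, on a coarser grid.
xs = [-1.0 + 2.0 * i / 2000 for i in range(2001)]
ys = [entmax_two_class(x) for x in xs]
```

Under these assumptions, plotting `ys` against `xs` should give a curve that stays at 0 for x <= -1/9, stays at 1 for x >= 1/9, and rises smoothly in between, which is the shape I was expecting from entmax_bisect.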