5 changes: 5 additions & 0 deletions entmax/activations.py
@@ -13,6 +13,7 @@
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd


def _make_ix_like(X, dim):
@@ -144,6 +145,7 @@ def _entmax_threshold_and_support(X, dim=-1, k=None):

class SparsemaxFunction(Function):
@classmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(cls, ctx, X, dim=-1, k=None):
ctx.dim = dim
max_val, _ = X.max(dim=dim, keepdim=True)
@@ -154,6 +156,7 @@ def forward(cls, ctx, X, dim=-1, k=None):
return output, supp_size

@classmethod
@custom_bwd
def backward(cls, ctx, grad_output, supp):
supp_size, output = ctx.saved_tensors
dim = ctx.dim
@@ -168,6 +171,7 @@ def backward(cls, ctx, grad_output, supp):

class Entmax15Function(Function):
@classmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(cls, ctx, X, dim=0, k=None):
ctx.dim = dim

@@ -182,6 +186,7 @@ def forward(cls, ctx, X, dim=0, k=None):
return Y, supp_size

@classmethod
@custom_bwd
def backward(cls, ctx, dY, supp):
Y, = ctx.saved_tensors
gppr = Y.sqrt() # = 1 / g'' (Y)
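The two decorators added above are the whole AMP story for these Functions: `custom_fwd(cast_inputs=torch.float32)` casts floating-point CUDA tensor inputs up to float32 and disables autocast inside `forward`, while `custom_bwd` makes `backward` execute with the same autocast state that the forward pass used. A minimal sketch of that pattern on a toy autograd Function (not part of this PR; the class name is made up for illustration):

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd

class DoubleFP32(Function):
    # Toy Function illustrating the decorator pattern applied to
    # SparsemaxFunction and Entmax15Function above.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, X):
        # Inside an autocast region, X has already been cast to float32 here.
        return 2 * X

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Runs with autocast disabled, matching the float32 forward.
        return 2 * grad_output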
5 changes: 5 additions & 0 deletions entmax/root_finding.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd


class EntmaxBisectFunction(Function):
@@ -27,6 +28,7 @@ def _p(cls, X, alpha):
return cls._gp_inv(torch.clamp(X, min=0), alpha)

@classmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):

if not isinstance(alpha, torch.Tensor):
@@ -71,6 +73,7 @@ def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
return p_m

@classmethod
@custom_bwd
def backward(cls, ctx, dY):
Y, = ctx.saved_tensors

@@ -118,12 +121,14 @@ def _p(cls, x, alpha):
return torch.clamp(x, min=0)

@classmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True):
return super().forward(
ctx, X, alpha=2, dim=dim, n_iter=50, ensure_sum_one=True
)

@classmethod
@custom_bwd
def backward(cls, ctx, dY):
Y, = ctx.saved_tensors
gppr = (Y > 0).to(dtype=dY.dtype)
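From the caller's side, the effect of these decorators is that the public mappings can be used inside an autocast region without any manual casting: the Function promotes its inputs to float32 for the forward and backward passes, so the returned probabilities come back in float32. A usage sketch (not part of the PR; assumes a CUDA device is available):

import torch
from entmax import entmax_bisect

x = torch.randn(8, 100, device="cuda", requires_grad=True)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    p = entmax_bisect(x, alpha=1.5, dim=-1)

# The decorated Function casts its inputs to float32, so p is float32
# even though the surrounding region autocasts to float16.
p.sum().backward()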
44 changes: 44 additions & 0 deletions entmax/test_amp.py
@@ -0,0 +1,44 @@
import pytest
import torch
from functools import partial

from entmax import entmax15, sparsemax, entmax_bisect

torch.manual_seed(42)

def make_negatives(dtype, max_pow):
negatives = []
for i in range(2, max_pow + 1):
negative = torch.randn(128, dtype=dtype, device="cuda") - 10 ** i
negative[0] += 5
negatives.append(negative)
return negatives


if torch.cuda.is_available():

mappings = [entmax15, sparsemax, partial(entmax_bisect, alpha=1.5), partial(entmax_bisect, alpha=2)]

long_bf16 = [
torch.randn(32000, dtype=torch.bfloat16, device="cuda")
for _ in range(5)
]
negatives_bf16 = make_negatives(torch.bfloat16, 7)

long_fp16 = [
torch.randn(32000, dtype=torch.float16, device="cuda")
for _ in range(5)
]
negatives_fp16 = make_negatives(torch.float16, 4)

@pytest.mark.parametrize("Xs", (long_bf16, negatives_bf16, long_fp16, negatives_fp16))
@pytest.mark.parametrize("func", mappings)
def test_probs_close(Xs, func):
dtype = Xs[0].dtype

full_precision_probs = [func(X.to(torch.float32), dim=-1) for X in Xs]
_Xs = [X.to(dtype) for X in Xs]
with torch.autocast(device_type="cuda", dtype=dtype):
for _X, fpp in zip(_Xs, full_precision_probs):
probs = func(_X, dim=-1)
assert torch.allclose(probs, fpp)
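These cases are only generated when `torch.cuda.is_available()` is true, so on a CPU-only machine the module simply collects no tests. With a GPU and the package installed, they can be run in the usual way, e.g. `pytest entmax/test_amp.py`.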