From e06db4bda611d745ee757800e13a9593a94430af Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Wed, 25 Oct 2023 17:05:41 +0100
Subject: [PATCH 01/15] compute exact algo with full precision when using amp

---
 entmax/activations.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/entmax/activations.py b/entmax/activations.py
index 53a3574..2b11c26 100644
--- a/entmax/activations.py
+++ b/entmax/activations.py
@@ -13,6 +13,7 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 
 def _make_ix_like(X, dim):
@@ -144,6 +145,7 @@ def _entmax_threshold_and_support(X, dim=-1, k=None):
 
 class SparsemaxFunction(Function):
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, dim=-1, k=None):
         ctx.dim = dim
         max_val, _ = X.max(dim=dim, keepdim=True)
@@ -154,6 +156,7 @@ def forward(cls, ctx, X, dim=-1, k=None):
         return output
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, grad_output):
         supp_size, output = ctx.saved_tensors
         dim = ctx.dim
@@ -168,6 +171,7 @@ def backward(cls, ctx, grad_output):
 
 class Entmax15Function(Function):
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, dim=0, k=None):
         ctx.dim = dim
 
@@ -182,6 +186,7 @@ def forward(cls, ctx, X, dim=0, k=None):
         return Y
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY):
         Y, = ctx.saved_tensors
         gppr = Y.sqrt()  # = 1 / g'' (Y)

From eed1fe3ed25f92bce4ac6ca3300e28c5fb561eab Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 11:33:11 +0100
Subject: [PATCH 02/15] add sums-to-one test for autocasting

---
 entmax/test_losses.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/entmax/test_losses.py b/entmax/test_losses.py
index 0b792a9..340e2d2 100644
--- a/entmax/test_losses.py
+++ b/entmax/test_losses.py
@@ -60,10 +60,3 @@ def test_index_ignored(Loss):
     loss_noignore = Loss(reduction="sum", ignore_index=-100)
 
     assert loss_ignore(x, y) < loss_noignore(x, y)
-
-
-if __name__ == "__main__":
-    test_sparsemax_loss()
-    test_entmax_loss()
-    test_sparsemax_bisect_loss()
-    test_entmax_bisect_loss()

From 241ab7cd178bad9389785e94797ef3f22d15c190 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 11:34:36 +0100
Subject: [PATCH 03/15] woops! now add amp test for real

---
 entmax/test_amp.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 entmax/test_amp.py

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
new file mode 100644
index 0000000..e43df05
--- /dev/null
+++ b/entmax/test_amp.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+from functools import partial
+
+from entmax import entmax15, sparsemax, entmax_bisect
+
+
+# These tests only work on cuda, so the first test will fail if you do not have it
+
+
+def test_cuda_available():
+    assert torch.cuda.is_available()
+
+
+if torch.cuda.is_available():
+
+    mappings = [entmax15, sparsemax, partial(entmax_bisect, alpha=1.5), partial(entmax_bisect, alpha=2)]
+
+    # make data
+    Xs = [
+        torch.randn(1000, dtype=torch.float32, device="cuda")
+        for _ in range(5)
+    ]
+
+
+    @pytest.mark.parameterize("func", mappings)
+    @pytest.mark.parameterize("dtype", (torch.bfloat16, torch.float16))
+    def test_sum_one(func, dtype):
+        with torch.autocast(device_type="cuda", dtype=dtype):
+            for X in Xs:
+                assert func(X).sum(-1).eq(1)

From 0efbdceaf464105c34c952802379212beea2d823 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 11:35:30 +0100
Subject: [PATCH 04/15] update spelling to parametrize

---
 entmax/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index e43df05..016aea4 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -23,8 +23,8 @@ def test_cuda_available():
     ]
 
 
-    @pytest.mark.parameterize("func", mappings)
-    @pytest.mark.parameterize("dtype", (torch.bfloat16, torch.float16))
+    @pytest.mark.parametrize("func", mappings)
+    @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
     def test_sum_one(func, dtype):
         with torch.autocast(device_type="cuda", dtype=dtype):
             for X in Xs:

From 1e8889c6d09a7368ab8dd7856a3c41d4f474a39c Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 10:49:11 +0000
Subject: [PATCH 05/15] refine tests to use all_close

---
 entmax/test_amp.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 016aea4..1f44bdd 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -26,6 +26,9 @@ def test_cuda_available():
     @pytest.mark.parametrize("func", mappings)
     @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
     def test_sum_one(func, dtype):
+        _Xs = [_X.to(dtype) for _X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
-            for X in Xs:
-                assert func(X).sum(-1).eq(1)
+            for _X in _Xs:
+                scores = func(_X)
+                prob_mass = scores.sum(-1)
+                assert torch.allclose(prob_mass, torch.tensor([1.0], device="cuda"))

From c56c1847fbbf05b9cf186e13b728e6fab2a14b06 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 11:59:29 +0100
Subject: [PATCH 06/15] add probs closeness test for amp

---
 entmax/test_amp.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 1f44bdd..0d5c2fc 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -22,13 +22,20 @@ def test_cuda_available():
         for _ in range(5)
     ]
 
-
     @pytest.mark.parametrize("func", mappings)
     @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
     def test_sum_one(func, dtype):
         _Xs = [_X.to(dtype) for _X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X in _Xs:
-                scores = func(_X)
+                scores = func(_X, dim=-1)
                 prob_mass = scores.sum(-1)
                 assert torch.allclose(prob_mass, torch.tensor([1.0], device="cuda"))
+
+    def test_probs_close(func, dtype):
+        full_precision_probs = [func(_X, dim=-1) for _X in Xs]
+        _Xs = [_X.to(dtype) for _X in Xs]
+        with torch.autocast(device_type="cuda", dtype=dtype):
+            for _X in _Xs:
+                probs = func(_X, dim=-1)
+                assert torch.allclose(probs, full_precision_probs)

From fac2cefb1496185dcee84613637b99c16a0be516 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 12:00:36 +0100
Subject: [PATCH 07/15] add missing decorators

---
 entmax/test_amp.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 0d5c2fc..6c45a22 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -32,6 +32,8 @@ def test_sum_one(func, dtype):
                 prob_mass = scores.sum(-1)
                 assert torch.allclose(prob_mass, torch.tensor([1.0], device="cuda"))
 
+    @pytest.mark.parametrize("func", mappings)
+    @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
     def test_probs_close(func, dtype):
         full_precision_probs = [func(_X, dim=-1) for _X in Xs]
         _Xs = [_X.to(dtype) for _X in Xs]

From 04e343bc2689d98cd9b1d466e78ee20ba1bd222f Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 12:02:36 +0100
Subject: [PATCH 08/15] fix test

---
 entmax/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 6c45a22..d11a475 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -38,6 +38,6 @@ def test_probs_close(func, dtype):
         full_precision_probs = [func(_X, dim=-1) for _X in Xs]
         _Xs = [_X.to(dtype) for _X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
-            for _X in _Xs:
+            for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)
-                assert torch.allclose(probs, full_precision_probs)
+                assert torch.allclose(probs, fpp)

From 50da1a006c0e59eef3d6e12d4f45b662472411af Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 11:13:48 +0000
Subject: [PATCH 09/15] update prob test

---
 entmax/test_amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index d11a475..c6f6abf 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -35,7 +35,7 @@ def test_sum_one(func, dtype):
     @pytest.mark.parametrize("func", mappings)
     @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
     def test_probs_close(func, dtype):
-        full_precision_probs = [func(_X, dim=-1) for _X in Xs]
+        full_precision_probs = [func(_X.to(dtype).to(torch.float32), dim=-1) for _X in Xs]
         _Xs = [_X.to(dtype) for _X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):

From 07e08de2631195e43381da2391caa7880ceae3c2 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 15:43:01 +0100
Subject: [PATCH 10/15] test on other kinds of tensors

---
 entmax/test_amp.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index c6f6abf..f7aac60 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -17,26 +17,34 @@ def test_cuda_available():
     mappings = [entmax15, sparsemax, partial(entmax_bisect, alpha=1.5), partial(entmax_bisect, alpha=2)]
 
     # make data
-    Xs = [
-        torch.randn(1000, dtype=torch.float32, device="cuda")
+    long_vecs = [
+        torch.randn(32000, dtype=torch.float32, device="cuda")
         for _ in range(5)
     ]
 
+    negatives = []
+    for i in range(2, 6):
+        negative = torch.randn(128, dtype=torch.float32, device="cuda") - 10 ** i
+        negative[0] += 5
+        negatives.append(negative)
+
+    @pytest.mark.parametrize("Xs", (long_vecs, negatives))
     @pytest.mark.parametrize("func", mappings)
     @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
-    def test_sum_one(func, dtype):
-        _Xs = [_X.to(dtype) for _X in Xs]
+    def test_sum_one(Xs, func, dtype):
+        _Xs = [X.to(dtype) for X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X in _Xs:
                 scores = func(_X, dim=-1)
                 prob_mass = scores.sum(-1)
                 assert torch.allclose(prob_mass, torch.tensor([1.0], device="cuda"))
 
+    @pytest.mark.parametrize("Xs", (long_vecs, negatives))
     @pytest.mark.parametrize("func", mappings)
     @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
-    def test_probs_close(func, dtype):
-        full_precision_probs = [func(_X.to(dtype).to(torch.float32), dim=-1) for _X in Xs]
-        _Xs = [_X.to(dtype) for _X in Xs]
+    def test_probs_close(Xs, func, dtype):
+        full_precision_probs = [func(X.to(dtype).to(torch.float32), dim=-1) for X in Xs]
+        _Xs = [X.to(dtype) for X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)

From 95423c3ce7a98ff6d1a45efe1f5ab45e74f1edb7 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 16:06:36 +0100
Subject: [PATCH 11/15] fix(?) bugs with allclose and nans

---
 entmax/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index f7aac60..5acd2b9 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -37,7 +37,7 @@ def test_sum_one(Xs, func, dtype):
             for _X in _Xs:
                 scores = func(_X, dim=-1)
                 prob_mass = scores.sum(-1)
-                assert torch.allclose(prob_mass, torch.tensor([1.0], device="cuda"))
+                assert torch.allclose(prob_mass, torch.tensor([1.0], dtype=dtype, device="cuda"))
 
     @pytest.mark.parametrize("Xs", (long_vecs, negatives))
     @pytest.mark.parametrize("func", mappings)
@@ -48,4 +48,4 @@ def test_probs_close(Xs, func, dtype):
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)
-                assert torch.allclose(probs, fpp)
+                assert torch.allclose(probs, fpp.to(dtype))

From 08ef3ed91db4e7afae18328d862585d5c4a06748 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 16:16:45 +0100
Subject: [PATCH 12/15] be more careful about dtype in tests

---
 entmax/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 5acd2b9..a08c49a 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -37,7 +37,7 @@ def test_sum_one(Xs, func, dtype):
             for _X in _Xs:
                 scores = func(_X, dim=-1)
                 prob_mass = scores.sum(-1)
-                assert torch.allclose(prob_mass, torch.tensor([1.0], dtype=dtype, device="cuda"))
+                assert torch.allclose(prob_mass, torch.tensor([1.0], dtype=prob_mass.dtype, device="cuda"))
 
     @pytest.mark.parametrize("Xs", (long_vecs, negatives))
     @pytest.mark.parametrize("func", mappings)
@@ -48,4 +48,4 @@ def test_probs_close(Xs, func, dtype):
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)
-                assert torch.allclose(probs, fpp.to(dtype))
+                assert torch.allclose(probs, fpp.to(probs.dtype))

From 461b2b709bcd650ab607a4cd12b20c86445c0d61 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Thu, 26 Oct 2023 16:18:54 +0100
Subject: [PATCH 13/15] autocast to float32 with bisection

---
 entmax/root_finding.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/entmax/root_finding.py b/entmax/root_finding.py
index 6de6f70..88ee5ff 100644
--- a/entmax/root_finding.py
+++ b/entmax/root_finding.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
+from torch.cuda.amp import custom_fwd, custom_bwd
 
 
 class EntmaxBisectFunction(Function):
@@ -26,6 +27,7 @@ def _p(cls, X, alpha):
         return cls._gp_inv(torch.clamp(X, min=0), alpha)
 
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
 
         if not isinstance(alpha, torch.Tensor):
@@ -69,6 +71,7 @@ def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
         return p_m
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY):
         Y, = ctx.saved_tensors
 
@@ -116,12 +119,14 @@ def _p(cls, x, alpha):
         return torch.clamp(x, min=0)
 
     @classmethod
+    @custom_fwd(cast_inputs=torch.float32)
    def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True):
        return super().forward(
            ctx, X, alpha=2, dim=dim, n_iter=50, ensure_sum_one=True
        )
 
     @classmethod
+    @custom_bwd
     def backward(cls, ctx, dY):
         Y, = ctx.saved_tensors
         gppr = (Y > 0).to(dtype=dY.dtype)

From 573a47dc3c22250ec875bec88c50234ed31d47fd Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Fri, 27 Oct 2023 14:39:56 +0100
Subject: [PATCH 14/15] refactor tests

---
 entmax/test_amp.py | 48 ++++++++++++++++++++----------------------------
 1 file changed, 20 insertions(+), 28 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index a08c49a..1d30a9e 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -5,47 +5,39 @@
 from entmax import entmax15, sparsemax, entmax_bisect
 
 
-# These tests only work on cuda, so the first test will fail if you do not have it
-
-
-def test_cuda_available():
-    assert torch.cuda.is_available()
+def make_negatives(dtype, max_pow):
+    negatives = []
+    for i in range(2, max_pow + 1):
+        negative = torch.randn(128, dtype=dtype, device="cuda") - 10 ** i
+        negative[0] += 5
+        negatives.append(negative)
+    return negatives
 
 
 if torch.cuda.is_available():
 
     mappings = [entmax15, sparsemax, partial(entmax_bisect, alpha=1.5), partial(entmax_bisect, alpha=2)]
 
-    # make data
-    long_vecs = [
-        torch.randn(32000, dtype=torch.float32, device="cuda")
+    long_bf16 = [
+        torch.randn(32000, dtype=torch.bfloat16, device="cuda")
         for _ in range(5)
     ]
+    negatives_bf16 = make_negatives(torch.bfloat16, 5)
 
-    negatives = []
-    for i in range(2, 6):
-        negative = torch.randn(128, dtype=torch.float32, device="cuda") - 10 ** i
-        negative[0] += 5
-        negatives.append(negative)
+    long_fp16 = [
+        torch.randn(32000, dtype=torch.float16, device="cuda")
+        for _ in range(5)
+    ]
+    negatives_fp16 = make_negatives(torch.float16, 3)
 
-    @pytest.mark.parametrize("Xs", (long_vecs, negatives))
+    @pytest.mark.parametrize("Xs", (long_bf16, negatives_bf16, long_fp16, negatives_fp16))
     @pytest.mark.parametrize("func", mappings)
-    @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
-    def test_sum_one(Xs, func, dtype):
-        _Xs = [X.to(dtype) for X in Xs]
-        with torch.autocast(device_type="cuda", dtype=dtype):
-            for _X in _Xs:
-                scores = func(_X, dim=-1)
-                prob_mass = scores.sum(-1)
-                assert torch.allclose(prob_mass, torch.tensor([1.0], dtype=prob_mass.dtype, device="cuda"))
+    def test_probs_close(Xs, func):
+        dtype = Xs[0].dtype
 
-    @pytest.mark.parametrize("Xs", (long_vecs, negatives))
-    @pytest.mark.parametrize("func", mappings)
-    @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
-    def test_probs_close(Xs, func, dtype):
-        full_precision_probs = [func(X.to(dtype).to(torch.float32), dim=-1) for X in Xs]
+        full_precision_probs = [func(X.to(torch.float32), dim=-1) for X in Xs]
         _Xs = [X.to(dtype) for X in Xs]
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)
-                assert torch.allclose(probs, fpp.to(probs.dtype))
+                assert torch.allclose(probs, fpp.to(dtype))

From 51ed45355ccd317598332e39033df6ee96ffb112 Mon Sep 17 00:00:00 2001
From: Ben Peters
Date: Fri, 27 Oct 2023 13:58:19 +0000
Subject: [PATCH 15/15] finally fix tests

---
 entmax/test_amp.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/entmax/test_amp.py b/entmax/test_amp.py
index 1d30a9e..afd5509 100644
--- a/entmax/test_amp.py
+++ b/entmax/test_amp.py
@@ -4,6 +4,7 @@
 
 from entmax import entmax15, sparsemax, entmax_bisect
 
+torch.manual_seed(42)
 
 def make_negatives(dtype, max_pow):
     negatives = []
@@ -22,13 +23,13 @@ def make_negatives(dtype, max_pow):
         torch.randn(32000, dtype=torch.bfloat16, device="cuda")
         for _ in range(5)
     ]
-    negatives_bf16 = make_negatives(torch.bfloat16, 5)
+    negatives_bf16 = make_negatives(torch.bfloat16, 7)
 
     long_fp16 = [
         torch.randn(32000, dtype=torch.float16, device="cuda")
        for _ in range(5)
     ]
-    negatives_fp16 = make_negatives(torch.float16, 3)
+    negatives_fp16 = make_negatives(torch.float16, 4)
 
     @pytest.mark.parametrize("Xs", (long_bf16, negatives_bf16, long_fp16, negatives_fp16))
     @pytest.mark.parametrize("func", mappings)
@@ -40,4 +41,4 @@ def test_probs_close(Xs, func):
         with torch.autocast(device_type="cuda", dtype=dtype):
             for _X, fpp in zip(_Xs, full_precision_probs):
                 probs = func(_X, dim=-1)
-                assert torch.allclose(probs, fpp.to(dtype))
+                assert torch.allclose(probs, fpp)
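
Taken together, the series decorates the exact sparsemax/entmax forward and backward passes with
torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) and custom_bwd, and adds tests that compare
autocast outputs against full-precision ones. The snippet below is not part of the patch series;
it is a minimal usage sketch of what the change enables, assuming a CUDA device and an entmax
build that includes these patches.

    import torch
    from entmax import entmax15, sparsemax, entmax_bisect

    x = torch.randn(8, 32000, device="cuda", requires_grad=True)

    # Inside an autocast region, custom_fwd(cast_inputs=torch.float32) casts the
    # inputs to float32 and runs the exact algorithms with autocast disabled, so
    # the outputs come back in float32 and sum to one even when the surrounding
    # model runs in float16 or bfloat16.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        p_sparse = sparsemax(x, dim=-1)
        p_15 = entmax15(x, dim=-1)
        p_bisect = entmax_bisect(x, alpha=1.5, dim=-1)

    print(p_15.dtype, p_15.sum(-1)[:3])  # torch.float32, values close to 1.0

    # custom_bwd pairs with custom_fwd so the backward pass runs with the same
    # (disabled) autocast state as the forward pass.
    p_15.sum().backward()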