Thanks for the great work.
```python
import torch
from adasplash import adasplash

seq_len = 128
batch_size = 2
device = "cuda"

q = torch.randn(batch_size, 8, seq_len, 64, device=device).requires_grad_(True)
k = torch.randn(batch_size, 8, seq_len, 64, device=device).requires_grad_(True)
v = torch.randn(batch_size, 8, seq_len, 64, device=device).requires_grad_(True)

out = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=False)
out.sum().backward()
```

Occasionally, running the above code raises the following error:
```
File adasplash_block_mask.py:312, in compute_bidxs_and_cubcounts(bmask, B, N_H, mblocks, nblocks, NEED_BACKWARD, device)
    310 if NEED_BACKWARD:
    311     kv_cubcount = torch.zeros((B * N_H * nblocks + 1,), device=device, dtype=torch.int32)  # fmt: skip
--> 312     kv_bidxs = bmask.nonzero(as_tuple=True)[3].to(torch.int16)
    313     torch.cumsum(bmask.sum(dim=-1).flatten(), dim=0, out=kv_cubcount[1:])
    315     q_cubcount = torch.zeros((B * N_H * mblocks + 1,), device=device, dtype=torch.int32)

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

Could you take a look at this?
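For context, here is a minimal CPU-only sketch of what the failing lines in `compute_bidxs_and_cubcounts` appear to compute, using a tiny hand-made block mask. The shapes and the toy `bmask` are my own illustration (with `mblocks == nblocks`, as in the repro where `seq_len = 128`), not the library's actual internals:

```python
import torch

# Illustrative dimensions (not the library's real values).
B, N_H, mblocks, nblocks = 1, 2, 2, 2

# bmask[b, h, m, n] == True means query block m attends to key block n.
bmask = torch.tensor(
    [[[[True, False],
       [True, True]],
      [[False, True],
       [True, True]]]]
)

# Line 312 in the traceback: column indices (dim 3) of the active blocks.
kv_bidxs = bmask.nonzero(as_tuple=True)[3].to(torch.int16)

# Lines 311/313: exclusive prefix sum of per-row active-block counts.
kv_cubcount = torch.zeros(B * N_H * nblocks + 1, dtype=torch.int32)
counts = bmask.sum(dim=-1).flatten().to(torch.int32)
kv_cubcount[1:] = torch.cumsum(counts, dim=0)

print(kv_bidxs.tolist())     # active block column indices
print(kv_cubcount.tolist())  # prefix sums, starting at 0
```

On CPU this runs fine, which is consistent with the error only surfacing asynchronously from an earlier CUDA kernel, as the traceback itself suggests.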