
Memory access fault when running certain matrix sizes in WMMA FA! #3

@lahmuller

Description


Hi Repeerc, I've pulled your code and successfully ran the benchmark on my W7900, but it goes wrong on certain tensor multiplications. ZLUDA is not installed, and I can still run bench_with_sdpa_BNHD.py and see the performance improvement. However, when I make the tensors bigger, the run aborts. Could this be related to ZLUDA, or is it a problem in the CUDA code?

Environment

Operating System: Ubuntu 22.04 (ZLUDA is not installed)
Software: PyTorch 2.3.0, Python 3.9
GPU: W7900 (gfx1100)

Reproduce

I changed the code in bench_with_sdpa_BNHD.py as follows:

(B, H, N, D) = (1, 32, 64, 128)
causal = False
dtype = torch.float16
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
n_list = []
flops_ft_list = []
maxmem_ft_list = []
flops_sdp_list = []
maxmem_sdp_list = []
for i in range(1, 15):
    # Note: N is never updated inside the loop, so every iteration
    # runs the same (B, H, N, D) shapes.
    q_shape = (B, N, H, D)
    k_shape = (B, N, H, D)
    v_shape = (B, N, H, D)
    print(f'B:{B}, H:{H}, SeqLen:{N}, DimHead:{D}')
    q = torch.rand(q_shape, dtype=dtype, device="cuda")
    k = torch.rand(k_shape, dtype=dtype, device="cuda")
    v = torch.rand(v_shape, dtype=dtype, device="cuda")
    r3, flops_ft, max_memory_ft, _ = ftt_rocm(q, k, v)
    r0, flops_sdp, max_memory_sdp, _ = sdp_pt(q, k, v)
    r3 = r3.cpu().to(torch.float32).transpose(1, 2)
    r0 = r0.cpu().to(torch.float32).transpose(1, 2)
    maxdiff = (r0 - r3).abs().max().item()
    print("max diff: ", maxdiff)
    n_list.append(N)
    flops_ft_list.append(flops_ft / 1e12)
    flops_sdp_list.append(flops_sdp / 1e12)
    maxmem_ft_list.append(max_memory_ft)
    maxmem_sdp_list.append(max_memory_sdp)
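Since the loop above never updates N, every iteration benchmarks the same shapes. If the intent was to grow the tensors each step ("make tensor bigger"), one possible shape schedule looks like the sketch below; the scaling rule (N = 64 * i) is an assumption, not taken from the original script, and ftt_rocm / sdp_pt would still be the helpers defined in bench_with_sdpa_BNHD.py.

```python
# Hypothetical shape schedule: grow SeqLen by 64 each iteration.
# Only the shape logic is shown; the torch.rand allocations and the
# ftt_rocm / sdp_pt benchmark calls would go inside the loop as before.
B, H, D = 1, 32, 128

def shape_schedule(steps=14, base_n=64):
    """Yield (B, N, H, D) tuples with N = base_n * i for i in 1..steps."""
    for i in range(1, steps + 1):
        N = base_n * i
        yield (B, N, H, D)

shapes = list(shape_schedule())
# shapes[0] == (1, 64, 32, 128); shapes[-1] == (1, 896, 32, 128)
```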

Error info

The error output from the code above:

Successfully preprocessed all matching files.
B:1, H:32, SeqLen:64, DimHead:128
ftt_rocm:       exec_time:0.0105, total_tflops:0.32, max_memory:2
/opt/conda/envs/py_3.9/lib/python3.9/contextlib.py:87: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
  self.gen = func(*args, **kwds)
/data/vllm-benchmark/llm/YS_test/flash-attention-v2-RDNA3-minimal-main/bench_with_sdpa_BNHD.py:68: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
  r0 = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)
sdp_pt:         exec_time:0.4190, total_tflops:0.01, max_memory:10
max diff:  0.0009765625
B:1, H:32, SeqLen:64, DimHead:128
ftt_rocm:       exec_time:0.0028, total_tflops:1.18, max_memory:2
sdp_pt:         exec_time:0.0180, total_tflops:0.19, max_memory:10
max diff:  0.0009765625
B:1, H:32, SeqLen:64, DimHead:128
Memory access fault by GPU node-2 (Agent handle: 0x977b2e0) on address 0x7f01fc801000. Reason: Page not present or supervisor privilege.
Aborted (core dumped)
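Because HIP kernel launches are asynchronous, the fault is typically reported some time after the offending launch. A debugging config that may help localize it (these are standard ROCm/HIP debug environment variables, not something specific to this repo) is:

```shell
export HIP_LAUNCH_BLOCKING=1     # make HIP kernel launches synchronous
export AMD_SERIALIZE_KERNEL=3    # serialize and sync after each kernel (ROCm debug flag)
# then rerun the repro: python bench_with_sdpa_BNHD.py
```

With launches serialized, the "Memory access fault" should be attributable to the last kernel that ran.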
