Help on AMD/ROCm/hip #79

@flashburns

Description

Multi-Hardware Support: RWKV-PEFT officially supports NVIDIA, AMD, Moore Threads, Musa, Iluvatar CoreX, and other hardware platforms. Ascend NPU implementation will be available later. Note: Currently we only support issues for NVIDIA hardware

I understand that this issue is not about NVIDIA hardware, so it's fine if you just end up closing it. I mainly wanted to document the problem I ran into, since AMD is listed as a supported platform.

Here are the training parameters/script I am using:

#!/usr/bin/env bash
set -e

load_model=/root/lilith_rwkv/rwkv7-g1b-2.9b-20251205-ctx8192.pth
proj_dir=/root/lilith_rwkv/out
data_file=/root/lilith_rwkv/lilith_rwkv

# convert data
# json2bin -i $data_file.jsonl

# do the fine tune

# 2.9B
n_layer=32
n_embd=2560

micro_bsz=8
epoch_save=1
epoch_steps=200
ctx_len=1024
peft_config='{"r":64,"lora_alpha":128,"lora_dropout":0.01}'

cd /root/RWKV-PEFT

python3 train.py --load_model $load_model \
	--proj_dir $proj_dir --data_file $data_file \
	--vocab_size 65536 \
	--data_type binidx \
	--n_layer $n_layer --n_embd $n_embd \
	--ctx_len $ctx_len --micro_bsz $micro_bsz \
	--epoch_steps $epoch_steps --epoch_count 10 --epoch_save $epoch_save \
	--lr_init 1e-5 --lr_final 1e-5 \
	--accelerator gpu --precision bf16 \
	--devices 1 --strategy deepspeed_stage_1 --grad_cp 1 \
	--my_testing "x070" \
	--peft lora --peft_config $peft_config

I have also tried adding --op cuda/fla and --fla; neither seems to change anything.

I tried editing rwkvt/operator/rwkvop.py to remove the unknown flags being passed to the AMD compiler, but it looks like a deeper fix involving the actual kernels is needed.
rwkvop.py
Commit at the time of this edit: 5704c39f8ab1d2ac63936ab392aadb6ba526e1a5
Error output with rwkvop.py edits: output.txt
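For context, the kind of flag filtering I tried in rwkvop.py looks roughly like this (a sketch, not the repo's actual code; `portable_cuda_flags` and `NVCC_ONLY_FLAGS` are my own names):

```python
# NVCC-only flags that hipcc/clang++ rejects with "unknown argument" (see log below).
NVCC_ONLY_FLAGS = {"-res-usage", "--use_fast_math", "--extra-device-vectorization"}

def portable_cuda_flags(flags, is_hip):
    """Drop NVCC-specific flags for a ROCm/HIP build.

    `is_hip` would come from `torch.version.hip is not None` in practice.
    `-Xptxas` takes a value argument (`-O3` here), so both tokens are dropped.
    """
    if not is_hip:
        return list(flags)  # NVIDIA build: keep everything
    out, skip_next = [], False
    for flag in flags:
        if skip_next:            # value consumed by a preceding -Xptxas
            skip_next = False
            continue
        if flag == "-Xptxas":
            skip_next = True
            continue
        if flag in NVCC_ONLY_FLAGS:
            continue
        out.append(flag)
    return out

flags = ["-res-usage", "-D_N_=64", "-D_CHUNK_LEN_=16", "--use_fast_math",
         "-O3", "-Xptxas", "-O3", "--extra-device-vectorization"]
print(portable_cuda_flags(flags, is_hip=True))
# → ['-D_N_=64', '-D_CHUNK_LEN_=16', '-O3']
```

Even with the flags filtered out, the build then fails on the missing hip/hip_bf16.h header, which is why I think a deeper fix in the kernels/build setup is needed.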

Here is the original output without the modified rwkvop.py:

########## work in progress ##########
/opt/venv/lib/python3.13/site-packages/torch/library.py:356: UserWarning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
    registered at /opt/venv/lib/python3.13/site-packages/torch/_library/custom_ops.py:922
  dispatch key: ADInplaceOrView
  previous kernel: no debug info
       new kernel: registered at /opt/venv/lib/python3.13/site-packages/torch/_library/custom_ops.py:922 (Triggered internally at /__w/TheRock/TheRock/external-builds/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
  self.m.impl(
/opt/venv/lib/python3.13/site-packages/torch/backends/__init__.py:46: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /__w/TheRock/TheRock/external-builds/pytorch/pytorch/aten/src/ATen/Context.cpp:45.)
  self.setter(val)
########## WKV OP           cuda               ##########

########## FUSED OP    False          ##########

RWKV_MY_TESTING x070
/root/RWKV-PEFT/cuda/rwkv7_clampw.cu -> /root/RWKV-PEFT/hip/rwkv7_clampw.hip [skipped, already hipified]
/root/RWKV-PEFT/cuda/rwkv7_clampw.cpp -> /root/RWKV-PEFT/hip/rwkv7_clampw.cpp [skipped, already hipified]
Successfully preprocessed all matching files.
Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 2
[1/3] /opt/venv/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=rwkv7_clampw -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/venv/lib/python3.13/site-packages/torch/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/THH -isystem /opt/venv/include -isystem /usr/include/python3.13 -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -std=c++17 --offload-arch=gfx1151 -fno-gpu-rdc -res-usage -D_N_=64 -D_CHUNK_LEN_=16 --use_fast_math -O3 -Xptxas -O3 --extra-device-vectorization -c /root/RWKV-PEFT/hip/rwkv7_clampw.hip -o rwkv7_clampw.cuda.o 
FAILED: [code=1] rwkv7_clampw.cuda.o 
/opt/venv/bin/hipcc  -DWITH_HIP -DTORCH_EXTENSION_NAME=rwkv7_clampw -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/venv/lib/python3.13/site-packages/torch/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/THH -isystem /opt/venv/include -isystem /usr/include/python3.13 -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -std=c++17 --offload-arch=gfx1151 -fno-gpu-rdc -res-usage -D_N_=64 -D_CHUNK_LEN_=16 --use_fast_math -O3 -Xptxas -O3 --extra-device-vectorization -c /root/RWKV-PEFT/hip/rwkv7_clampw.hip -o rwkv7_clampw.cuda.o 
clang++: error: unknown argument: '-res-usage'
clang++: error: unknown argument: '--use_fast_math'
clang++: error: unknown argument: '-Xptxas'
clang++: error: unknown argument: '--extra-device-vectorization'
failed to execute:/opt/venv/lib/python3.13/site-packages/_rocm_sdk_core/lib/llvm/bin/clang++  --offload-arch=gfx1151  -DWITH_HIP -DTORCH_EXTENSION_NAME=rwkv7_clampw -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/venv/lib/python3.13/site-packages/torch/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/THH -isystem /opt/venv/include -isystem /usr/include/python3.13 -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -std=c++17 -fno-gpu-rdc -res-usage -D_N_=64 -D_CHUNK_LEN_=16 --use_fast_math -O3 -Xptxas -O3 --extra-device-vectorization -c -x hip /root/RWKV-PEFT/hip/rwkv7_clampw.hip -o "rwkv7_clampw.cuda.o"
[2/3] c++ -MMD -MF rwkv7_clampw.o.d -DTORCH_EXTENSION_NAME=rwkv7_clampw -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/venv/lib/python3.13/site-packages/torch/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/THH -isystem /opt/venv/include -isystem /usr/include/python3.13 -fPIC -std=c++17 -c /root/RWKV-PEFT/hip/rwkv7_clampw.cpp -o rwkv7_clampw.o -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC
FAILED: [code=1] rwkv7_clampw.o 
c++ -MMD -MF rwkv7_clampw.o.d -DTORCH_EXTENSION_NAME=rwkv7_clampw -DTORCH_API_INCLUDE_EXTENSION_H -isystem /opt/venv/lib/python3.13/site-packages/torch/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.13/site-packages/torch/include/THH -isystem /opt/venv/include -isystem /usr/include/python3.13 -fPIC -std=c++17 -c /root/RWKV-PEFT/hip/rwkv7_clampw.cpp -o rwkv7_clampw.o -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC
/root/RWKV-PEFT/hip/rwkv7_clampw.cpp:7:14: fatal error: hip/hip_bf16.h: No such file or directory
    7 |     #include <hip/hip_bf16.h>
      |              ^~~~~~~~~~~~~~~~
compilation terminated.
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/opt/venv/lib/python3.13/site-packages/torch/utils/cpp_extension.py", line 2620, in _run_ninja_build
    subprocess.run(
    ~~~~~~~~~~~~~~^
        command,
        ^^^^^^^^
    ...<4 lines>...
        check=True,
        ^^^^^^^^^^^
        env=env)
        ^^^^^^^^
  File "/usr/lib/python3.13/subprocess.py", line 577, in run
    raise CalledProcessError(retcode, process.args,
                             output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/RWKV-PEFT/train.py", line 253, in <module>
    from rwkvt.peft_loading import load_peft_model
  File "/root/RWKV-PEFT/rwkvt/peft_loading.py", line 5, in <module>
    from rwkvt.lightning_train.light_rwkv import RWKV
  File "/root/RWKV-PEFT/rwkvt/lightning_train/light_rwkv.py", line 31, in <module>
    from rwkvt.rwkv7.model import RWKV7 as RWKVModel
  File "/root/RWKV-PEFT/rwkvt/rwkv7/model.py", line 11, in <module>
    from .block import Block
  File "/root/RWKV-PEFT/rwkvt/rwkv7/block.py", line 4, in <module>
    from .att import RWKV_Tmix_v7
  File "/root/RWKV-PEFT/rwkvt/rwkv7/att.py", line 8, in <module>
    from rwkvt.operator.rwkvop import RUN_CUDA_RWKV7g, RUN_RWKV7_STATE, RUN_RWKV7_INFCTX
  File "/root/RWKV-PEFT/rwkvt/operator/rwkvop.py", line 380, in <module>
    load(name="rwkv7_clampw", sources=[f'cuda/rwkv7_clampw.cu', 'cuda/rwkv7_clampw.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags)
    ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.13/site-packages/torch/utils/cpp_extension.py", line 1710, in load
    return _jit_compile(
        name,
    ...<11 lines>...
        is_standalone,
        keep_intermediates=keep_intermediates)
  File "/opt/venv/lib/python3.13/site-packages/torch/utils/cpp_extension.py", line 2152, in _jit_compile
    _write_ninja_file_and_build_library(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        name=name,
        ^^^^^^^^^^
    ...<9 lines>...
        with_sycl=with_sycl,
        ^^^^^^^^^^^^^^^^^^^^
        is_standalone=is_standalone)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.13/site-packages/torch/utils/cpp_extension.py", line 2304, in _write_ninja_file_and_build_library
    _run_ninja_build(
    ~~~~~~~~~~~~~~~~^
        build_directory,
        ^^^^^^^^^^^^^^^^
        verbose,
        ^^^^^^^^
        error_prefix=f"Error building extension '{name}'")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/venv/lib/python3.13/site-packages/torch/utils/cpp_extension.py", line 2637, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'rwkv7_clampw'
[W116 23:59:13.476275198 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
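As a quick diagnostic for the hip/hip_bf16.h error, I checked whether the ROCm SDK bundled with these wheels ships that header at all. This is a sketch; the `_rocm_sdk*` directory pattern is a guess based on the compiler path in the log (`.../_rocm_sdk_core/lib/llvm/bin/clang++`):

```python
import glob
import os
import sysconfig

def find_hip_bf16_headers(site_packages=None):
    """Search installed ROCm SDK wheel directories for hip_bf16.h.

    Returns a (possibly empty) list of matching header paths.
    """
    site_packages = site_packages or sysconfig.get_paths()["purelib"]
    pattern = os.path.join(site_packages, "_rocm_sdk*", "**", "hip_bf16.h")
    return glob.glob(pattern, recursive=True)

print(find_hip_bf16_headers() or "hip_bf16.h not found in the bundled ROCm SDK")
```

If the header is present but under a nonstandard include root, the extension build would presumably need that directory passed via `extra_include_paths`; if it is absent, the SDK packaging itself is missing it.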
