
bitsandbytes loading models in 8-bit doesn't work on Radeon GPUs (gfx1100) #41

@farshadghodsian

Description


System Info

Platform: Ubuntu 22.04
Python Version: 3.10
Hardware: Threadripper Pro 5975wx, WRX80 motherboard, 128GB RAM, 1x Radeon Pro W7900 GPU
Python Libraries:

accelerate==0.31.0
fastapi==0.111.0
gradio==4.37.2
torch==2.5.0.dev20240705+rocm6.1
transformers==4.37.2

bitsandbytes: built and installed from the ROCm/bitsandbytes rocm_enabled branch, following the instructions in that branch

Reproduction

I was able to get models loading in 4-bit using ROCm/bitsandbytes with the transformers library. To make it work I had to tell PyTorch not to use hipBLASLt, since Radeon (gfx1100) GPUs do not support it, by setting the environment variable TORCH_BLAS_PREFER_HIPBLASLT=0; that is an upstream PyTorch issue and not related to bitsandbytes. While this workaround is enough for 4-bit, inference fails when I load the same model in 8-bit. Even with TORCH_BLAS_PREFER_HIPBLASLT=0 set, the reproduction below raises an error.
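For completeness, this is roughly how I disable hipBLASLt from inside the script; I set the variable before importing torch to be safe (exporting it in the shell before launching Python works just as well):

import os

# Tell PyTorch's ROCm backend to prefer rocBLAS over hipBLASLt, which is not
# supported on gfx1100 cards such as the Radeon Pro W7900. Set this before
# importing torch.
os.environ["TORCH_BLAS_PREFER_HIPBLASLT"] = "0"

import torch

The failing 8-bit reproduction: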

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_name = "facebook/opt-350m"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example input
input_text = "Hello, world!"

# Prepare the input for the model
inputs = tokenizer.encode_plus(
    input_text,
    add_special_tokens=True,
    max_length=512,
    return_attention_mask=True,
    return_tensors='pt'
)

# Generate text using the model
generated_text = model_8bit.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'])

# Decode the generated text using the tokenizer
decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)

# Print the generated output
print(decoded_text)

Error:

rocblaslt warning: No paths matched /home/amd-user/.local/lib/python3.10/site-packages/torch/lib/hipblaslt/library/*gfx1100*co. Make sure that HIPBLASLT_TENSILE_LIBPATH is set correctly.
A: torch.Size([5, 512]), B: torch.Size([1024, 512]), C: (5, 1024); (lda, ldb, ldc): (c_int(5), c_int(1024), c_int(5)); (m, n, k): (c_int(5), c_int(1024), c_int(512))
error detected
Traceback (most recent call last):
  File "/home/amd-user/Documents/bitsandbytes 8-bit test.py", line 23, in <module>
    generated_text = model_8bit.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'])
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1479, in generate
    return self.greedy_search(
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2340, in greedy_search
    outputs = self(
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 1145, in forward
    outputs = self.model.decoder(
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 863, in forward
    inputs_embeds = self.project_in(inputs_embeds)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 801, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 559, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/amd-user/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 398, in forward
    out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
  File "/home/amd-user/.local/lib/python3.10/site-packages/bitsandbytes/functional.py", line 2388, in igemmlt
    raise Exception("cublasLt ran into an error!")
Exception: cublasLt ran into an error!

Expected behavior

The above code works fine when I load the model in 4-bit, changing only quantization_config = BitsAndBytesConfig(load_in_4bit=True) instead of quantization_config = BitsAndBytesConfig(load_in_8bit=True). The same issue also occurs when using Hugging Face's accelerate library with load_in_8bit=True.
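For comparison, the only change on the working 4-bit path is the quantization config (the rest of the script above stays the same):

from transformers import BitsAndBytesConfig

# 4-bit loading and generation work on the W7900 with the default settings;
# no extra options are needed beyond load_in_4bit=True.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)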

I don't use 8-bit a whole lot, so this is not a priority for me, but I thought I should report it so that others are aware. Loading models in 4-bit with bitsandbytes works without issues.
