
MedGemma 4B p.attn_bias_ptr is not correctly aligned with long system prompt #57

@millionsofluo

Description


This document summarizes a reproducible CUDA runtime error observed when running MedGemma 1.5 4B IT locally with a long Chinese system prompt, so that it can be filed as a GitHub issue for the transformers / MedGemma model maintainers.


1. Environment

  • GPU: NVIDIA RTX 3090 (24 GB)
  • Driver Version: 590.48.01
  • CUDA Version (driver): 13.1
  • OS: Ubuntu 22.04
  • Python: 3.10
  • PyTorch stack:
    • torch==2.5.1+cu121
    • torchvision==0.20.1+cu121
    • torchaudio==2.5.1+cu121
  • Transformers: transformers==4.51.0
  • Model: local copy of google/medgemma-1.5-4b-it
  • Pipeline task: "image-text-to-text"

The model is loaded via the transformers.pipeline API.


2. Model loading code

The model is managed by a small singleton wrapper; relevant part (current version, using bfloat16 on CUDA, float32 on CPU):

import torch
from transformers import pipeline

from src.config import ModelConfig
from src.common.logger import INFO, ERROR


class ModelManager:
    ...

    def load(self) -> None:
        """Load pipeline, only once."""
        if self._pipe is not None:
            return
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # CUDA uses bfloat16, CPU uses float32
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
        try:
            self._pipe = pipeline(
                ModelConfig.PIPELINE_TASK,   # "image-text-to-text"
                model=ModelConfig.MODEL_PATH,
                torch_dtype=dtype,
                device=device,
            )
            INFO(f"Model loaded, device={device}")
        except Exception as e:
            ERROR(f"Failed to load model: {e}", exc_info=True)
            raise

    def get_pipeline(self):
        self.load()
        return self._pipe


model_manager = ModelManager()

ModelConfig.MODEL_PATH points to a local directory containing google/medgemma-1.5-4b-it.


3. Inference call pattern

Inference is executed roughly like this (simplified from src/core/inference.py):

from PIL import Image

from src.core.model_manager import model_manager


def run_inference(
    image: Image.Image | None,
    question: str,
    system_image: str,
    system_text_only: str,
    max_new_tokens: int,
    default_system_image: str,
    default_system_text_only: str,
) -> str:
    question = (question or "").strip()
    if not question:
        return "Please input a question."

    if image is not None:
        system_instruction = (system_image or "").strip() or default_system_image
        user_content = [
            {"type": "text", "text": question},
            {"type": "image", "image": image},
        ]
    else:
        system_instruction = (system_text_only or "").strip() or default_system_text_only
        # For the MedGemma pipeline, we still pass a placeholder image for text-only mode
        img_used = _get_placeholder_image()
        user_content = [
            {"type": "text", "text": question},
            {"type": "image", "image": img_used},
        ]

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
        {"role": "user", "content": user_content},
    ]

    model_manager.load()
    pipe = model_manager.get_pipeline()
    out = pipe(text=messages, max_new_tokens=max_new_tokens)
    result = out[0]["generated_text"][-1]["content"]
    return result

The key point is that we always pass messages as text=... with the MedGemma-style multimodal message format.
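The exact structure passed as text=... can be reproduced standalone for inspection (a minimal sketch; build_messages is a hypothetical helper mirroring the construction in run_inference above):

```python
# Hypothetical helper mirroring the message construction in run_inference;
# useful for comparing the payload between the long and short system prompts.
def build_messages(question, system_instruction, image):
    user_content = [
        {"type": "text", "text": question},
        {"type": "image", "image": image},
    ]
    return [
        {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
        {"role": "user", "content": user_content},
    ]


messages = build_messages(
    "What does this scan show?",
    "You are a radiology assistant.",
    "<PIL image here>",  # a PIL.Image.Image in the real code
)
```

Printing this object is how we verified that the only variable between the failing and passing runs is the length of the system text.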


4. Long system prompt that triggers the error

The problematic system prompt is a relatively long, structured Chinese instruction used only for the image branch (it tells the model to act as a senior radiologist's assistant, answer in Chinese with a fixed three-part report structure, and use cautious wording). It lives in data/prompt.json as one of the presets:

{
  "default": {
    "image": "你作为医院AI影像辅助系统,角色是资深放射科医生助手,仅辅助临床医生分析影像。请严格遵守:\n- 只基于提供的影像和医生问题作答\n- 输出必须用中文\n- 固定结构:\n  【主要发现】\n  【印象与分析】\n  【建议与注意】\n- 使用谨慎语言:如‘提示可能’‘考虑’‘建议排除’‘需结合临床’\n- 如信息不足,明确写‘影像表现不典型,建议补充XX检查’\n诊断、治疗方案及责任由医生承担。",
    "text_only": "你是一名医学专业助手,作为医院AI辅助工具,仅为医生提供参考意见,不替代临床判断。请用中文给出简洁、专业、有条理的回答。结构建议:问题复述 → 核心解答 → 依据或理由 → 临床建议/注意事项。如果不确定或需更多信息,请明确说明‘建议结合患者完整病史/检查结果进一步评估’。最终诊断与治疗由医生负责。"
  },
  "alternatives": {
    "image": [
      {
        "name": "默认 强化版 (有bug)",
        "prompt": "你是一名经验丰富的影像科医生,作为医院AI辅助系统,仅为临床医生提供参考意见,绝不替代医生诊断。请严格基于提供的医学影像和医生提出的具体问题,按照以下结构用中文输出:\n1. 主要发现(按解剖顺序描述明显异常,正常结构可简略)\n2. 印象/初步分析(可能诊断或鉴别诊断,标注置信度,如‘高度提示’‘可能’‘不能排除’)\n3. 建议(进一步检查、随访或其他临床注意事项)\n\n如果影像质量差、信息不足或超出你的能力范围,请明确说明‘建议结合临床及其他检查综合判断’或‘图像质量影响评估,建议重新拍摄’。诊断与最终决策由医生负责。请保持简洁、专业、客观,避免绝对化表述。"
      },
      ...
    ],
    ...
  }
}
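The presets are consumed roughly as follows (a hypothetical loader sketch; the real selection code in src/ is not shown in this report, and the prompt strings below are placeholders, not the actual Chinese instructions):

```python
import json

# Hypothetical loader for data/prompt.json. Placeholder strings stand in
# for the long Chinese prompts above; only the structure matters here.
raw = """
{
  "default": {
    "image": "placeholder image-branch system prompt",
    "text_only": "placeholder text-only system prompt"
  },
  "alternatives": {
    "image": [
      {"name": "enhanced (buggy)", "prompt": "placeholder long prompt"}
    ]
  }
}
"""
presets = json.loads(raw)
default_system_image = presets["default"]["image"]
alternative_names = [p["name"] for p in presets["alternatives"]["image"]]
```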

Empirically:

  • When this long image prompt (默认 强化版 (有bug), roughly "default, enhanced version (has bug)") is used as the system prompt for the image case (i.e. image is not None), we see the CUDA error below.
  • When we switch to shorter image prompts (e.g. the original Google demo-style prompt, or other shorter templates), the error disappears.
  • When we run text-only (no user image, using system_text_only + a small placeholder image), we do not see this error.

So the failure appears to be sensitive to the total sequence length / shape produced by this long system prompt in combination with an actual image input.


5. Error message / stack trace

The runtime error logged is:

RuntimeError: p.attn_bias_ptr is not correctly aligned

In the application logs it appears as follows (abbreviated; 推理失败 means "inference failed"):

[ERROR] 推理失败: p.attn_bias_ptr is not correctly aligned | inference.py:94
Traceback (most recent call last):
  File ".../site-packages/gradio/queueing.py", line ...
  ...
  File "/workspace/medgemma/src/core/inference.py", line 79, in run_inference
    out = pipe(text=messages, max_new_tokens=max_new_tokens)
  File ".../site-packages/transformers/pipelines/base.py", line ...
  ...
RuntimeError: p.attn_bias_ptr is not correctly aligned

There is no custom CUDA code in this project; everything is using the standard transformers pipeline.


6. Observations and experiments

  1. Switching system prompts

    • Long image system prompt (默认 强化版 (有bug), the "enhanced" preset above) + real image → error.
    • Shorter image prompts (including the original Google medgemma demo prompt) + same image → no error.
    • Text-only system prompts (no user image; uses a placeholder image) → no error.
  2. Switching dtype to float16 on CUDA

    • When we experimented with forcing torch.float16 for CUDA instead of bfloat16, the error actually became more frequent: almost all prompts started to trigger p.attn_bias_ptr is not correctly aligned.
    • Because of that, we reverted to the bfloat16 configuration shown above.
  3. Model size

    • The issue occurs with MedGemma 1.5 4B IT; we have not yet tried the 27B model.
    • However, this looks like a kernel / alignment bug rather than a capacity/quality issue; the model works fine with shorter prompts and the same image.
  4. No custom attention / fused kernel code

    • We are not using any custom fused attention kernels in this project; everything is the stock transformers pipeline on top of PyTorch 2.5.1 + cu121.

7. What we suspect

Given the above, our working hypothesis is:

  • There is a bug in one of the attention-related CUDA kernels used by the MedGemma 4B pipeline for the "image-text-to-text" task, in the combination of:
    • RTX 3090 (Ampere)
    • CUDA 13.1 driver
    • torch==2.5.1+cu121 / transformers==4.51.0
    • CUDA bfloat16 (and even more so with float16)
    • Certain sequence lengths / shapes induced by this long Chinese system prompt + image tokens.

In other words, the same code path and same image works fine for shorter prompts, but fails with this longer system prompt, which strongly suggests a shape-/length-sensitive alignment issue in the underlying kernel.
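One way to test this hypothesis (untested here) is to rule the memory-efficient SDPA backend in or out explicitly: the p.attn_bias_ptr alignment assertion appears to come from the memory-efficient attention kernels that PyTorch's scaled_dot_product_attention can dispatch to, and torch exposes a process-wide switch for that backend:

```python
import torch

# Disable only the memory-efficient SDPA backend; the flash and math
# backends stay enabled, so scaled_dot_product_attention falls back to
# them. Global, process-wide switch: call it before loading the pipeline.
torch.backends.cuda.enable_mem_efficient_sdp(False)

print(torch.backends.cuda.mem_efficient_sdp_enabled())  # False
print(torch.backends.cuda.math_sdp_enabled())           # still True
```

If the error disappears with this switch, that would localize the bug to the memory-efficient kernel; torch.backends.cuda.enable_mem_efficient_sdp(True) restores the default behavior.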


8. What we are asking from maintainers

  • Could you please:
    • Confirm whether this is a known issue with MedGemma 4B + CUDA bf16/fp16 on Ampere GPUs?
    • Advise on whether there is a recommended workaround (e.g. specific torch / transformers / xFormers version, or disabling a particular fused kernel) to avoid this error?
    • If needed, we can provide a minimal reproduction script using the code snippets above and a dummy image.

For now, our practical workaround is to:

  • Prefer shorter / more compact system prompts for image-based queries, and
  • Keep using CUDA bfloat16 (instead of forcing float16) on the 3090.
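Another knob we could try, if maintainers think the fused kernel path is at fault, is forcing the eager attention implementation at load time. attn_implementation is a standard from_pretrained argument in recent transformers releases and can be forwarded through the pipeline's model_kwargs; whether it avoids this particular error is untested:

```python
# Hypothetical variant of the ModelManager.load() call from section 2.
# "eager" selects the plain PyTorch attention path instead of fused SDPA;
# we have NOT verified that this avoids the alignment error.
model_kwargs = {"attn_implementation": "eager"}

# self._pipe = pipeline(
#     ModelConfig.PIPELINE_TASK,    # "image-text-to-text"
#     model=ModelConfig.MODEL_PATH,
#     torch_dtype=dtype,
#     device=device,
#     model_kwargs=model_kwargs,    # forwarded to from_pretrained
# )
```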

But we would very much like to help get this alignment issue fixed upstream if possible.
