
ValueError in detector_bayesian.train_best_detector due to inverted GPU check #26

@amanmprojects

Description:

The detector_bayesian.BayesianDetector.train_best_detector method raises ValueError: "We have found the training unstable on CPUs... Use GPU or TPU for training." even when a GPU is available, JAX detects it, and the correct torch.device object is passed.

Environment:

  • Operating System: Windows 11 24H2 with WSL2 (Ubuntu 22.04)
  • Python Version: 3.12.3
  • Key Library Versions (from pip freeze):
    • synthid-text: 0.2.1
    • jax: 0.5.3
    • jaxlib: 0.5.3
    • jax-cuda12-plugin: 0.5.3
    • flax: 0.10.5
    • torch: 2.4.0
    • transformers: 4.51.0
  • CUDA Toolkit Version: 12.7
  • GPU Model: NVIDIA RTX 4060

(Full pip freeze output attached below if needed by maintainers)

pip freeze output
absl-py==2.2.2
accelerate==1.6.0
annotated-types==0.7.0
anyio==4.9.0
bitsandbytes==0.45.4
certifi==2025.1.31
charset-normalizer==3.4.1
chex==0.1.89
click==8.1.8
etils==1.12.2
fastapi==0.115.12
filelock==3.18.0
flax==0.10.5
fsspec==2025.3.2
h11==0.14.0
huggingface-hub==0.30.1
humanize==4.12.2
idna==3.10
immutabledict==4.2.0
importlib_resources==6.5.2
jax==0.5.3
jax-cuda12-pjrt==0.5.3
jax-cuda12-plugin==0.5.3
jaxlib==0.5.3
jaxtyping==0.3.1
Jinja2==3.1.6
joblib==1.4.2
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
msgpack==1.1.0
nest-asyncio==1.6.0
networkx==3.4.2
numpy==2.2.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.8.93
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
opt_einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.10
packaging==24.2
pandas==2.2.3
protobuf==6.30.2
psutil==7.0.0
pydantic==2.11.2
pydantic_core==2.33.1
Pygments==2.19.1
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
pytz==2025.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
setuptools==78.1.0
simplejson==3.20.1
six==1.17.0
sniffio==1.3.1
starlette==0.46.1
sympy==1.13.1
synthid-text==0.2.1
tensorstore==0.1.73
threadpoolctl==3.6.0
tokenizers==0.21.1
toolz==1.0.0
torch==2.4.0
tqdm==4.67.1
transformers==4.51.0
treescope==0.1.9
triton==3.0.0
typing-inspection==0.4.0
typing_extensions==4.13.1
tzdata==2025.2
urllib3==2.3.0
uv==0.6.12
uvicorn==0.34.0
wadler_lindig==0.1.4
zipp==3.21.0

Bug Details:

  1. The environment is configured with CUDA, and JAX correctly identifies the GPU (verified via jax.devices()):
    JAX available devices: [CudaDevice(id=0)]
    
  2. The train_best_detector method is called with torch_device correctly set to torch.device('cuda'). Verified with debug prints:
    DEBUG: About to call train_best_detector with:
      torch_device = cuda
      type(torch_device) = <class 'torch.device'>
      torch_device.type = cuda
    
  3. Despite the above, the function raises the ValueError: We have found the training unstable on CPUs... Use GPU or TPU for training.
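
For reference, the device visibility described in points 1 and 2 can be reproduced with a short standalone script (a minimal sketch; it only checks what JAX and PyTorch report and does not call train_best_detector itself):

    # Verify that both JAX and PyTorch see the GPU before training is attempted.
    import jax
    import torch

    print("JAX devices:", jax.devices())                  # expected: [CudaDevice(id=0)]
    print("CUDA available:", torch.cuda.is_available())   # expected: True

    torch_device = torch.device("cuda")
    print("torch_device.type:", torch_device.type)        # expected: "cuda"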

Root Cause:

The check within the train_best_detector method in src/synthid_text/detector_bayesian.py (around line 1032 in version 0.2.1) appears to have inverted logic:

      # From src/synthid_text/detector_bayesian.py
      if torch_device.type in ("cuda", "tpu"): # Checks if device IS cuda or tpu
          raise ValueError(                     # RAISES ERROR IF IT IS!
              "We have found the training unstable on CPUs; we are working on"
              " a fix. Use GPU or TPU for training."
          )

This condition incorrectly causes the error to be raised when a CUDA or TPU device is detected and passed, instead of when a CPU is detected.
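
The inverted condition can also be demonstrated in isolation (a small sketch, independent of synthid-text):

    # With a CUDA device, the shipped check evaluates to True, so the
    # "use GPU or TPU" error fires for exactly the devices it should allow.
    import torch

    torch_device = torch.device("cuda")
    print(torch_device.type in ("cuda", "tpu"))      # True  -> current code raises
    print(torch_device.type not in ("cuda", "tpu"))  # False -> fixed check does not raise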

Proposed Fix:

Changing the condition to use not in resolves the issue:

-     if torch_device.type in ("cuda", "tpu"):
+     if torch_device.type not in ("cuda", "tpu"):
          raise ValueError(
              "We have found the training unstable on CPUs; we are working on"
              " a fix. Use GPU or TPU for training."
          )
Applying this change locally resolved the incorrect ValueError about CPU usage.
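
Until a patched release is available, the file containing the check can be located so the one-line change above can be applied to the installed copy (a stopgap workaround, not an official fix):

    # Print the path of the installed detector_bayesian.py so the diff above
    # can be applied locally until a fixed synthid-text release ships.
    import synthid_text.detector_bayesian as detector_bayesian

    print(detector_bayesian.__file__)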
