Description:
The `detector_bayesian.BayesianDetector.train_best_detector` method raises `ValueError: We have found the training unstable on CPUs... Use GPU or TPU for training.` even when a GPU is available, JAX detects it, and the correct `torch.device` object is passed.
Environment:
- Operating System: Windows 11 24H2 with WSL2 (Ubuntu 22.04)
- Python Version: 3.12.3
- Key Library Versions (from pip freeze):
  - synthid-text: 0.2.1
  - jax: 0.5.3
  - jaxlib: 0.5.3
  - jax-cuda12-plugin: 0.5.3
  - flax: 0.10.5
  - torch: 2.4.0
  - transformers: 4.51.0
- CUDA Toolkit Version: 12.7
- GPU Model: NVIDIA RTX 4060
(Full pip freeze output attached below if needed by maintainers)
pip freeze output
absl-py==2.2.2
accelerate==1.6.0
annotated-types==0.7.0
anyio==4.9.0
bitsandbytes==0.45.4
certifi==2025.1.31
charset-normalizer==3.4.1
chex==0.1.89
click==8.1.8
etils==1.12.2
fastapi==0.115.12
filelock==3.18.0
flax==0.10.5
fsspec==2025.3.2
h11==0.14.0
huggingface-hub==0.30.1
humanize==4.12.2
idna==3.10
immutabledict==4.2.0
importlib_resources==6.5.2
jax==0.5.3
jax-cuda12-pjrt==0.5.3
jax-cuda12-plugin==0.5.3
jaxlib==0.5.3
jaxtyping==0.3.1
Jinja2==3.1.6
joblib==1.4.2
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
msgpack==1.1.0
nest-asyncio==1.6.0
networkx==3.4.2
numpy==2.2.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.8.93
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
opt_einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.10
packaging==24.2
pandas==2.2.3
protobuf==6.30.2
psutil==7.0.0
pydantic==2.11.2
pydantic_core==2.33.1
Pygments==2.19.1
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
pytz==2025.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
setuptools==78.1.0
simplejson==3.20.1
six==1.17.0
sniffio==1.3.1
starlette==0.46.1
sympy==1.13.1
synthid-text==0.2.1
tensorstore==0.1.73
threadpoolctl==3.6.0
tokenizers==0.21.1
toolz==1.0.0
torch==2.4.0
tqdm==4.67.1
transformers==4.51.0
treescope==0.1.9
triton==3.0.0
typing-inspection==0.4.0
typing_extensions==4.13.1
tzdata==2025.2
urllib3==2.3.0
uv==0.6.12
uvicorn==0.34.0
wadler_lindig==0.1.4
zipp==3.21.0
Bug Details:
- The environment is configured with CUDA, and JAX correctly identifies the GPU (verified via `jax.devices()`): `JAX available devices: [CudaDevice(id=0)]`
- The `train_best_detector` method is called with `torch_device` correctly set to `torch.device('cuda')`. Verified with debug prints just before the call: `DEBUG: About to call train_best_detector with: torch_device = cuda`, `type(torch_device) = <class 'torch.device'>`, `torch_device.type = cuda`
- Despite the above, the function raises the `ValueError: We have found the training unstable on CPUs... Use GPU or TPU for training.`
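For reference, a minimal sketch of the checks used to confirm the accelerator is visible to both frameworks (the print statements are illustrative; the quoted output above came from equivalent calls):

```python
import jax
import torch

# JAX sees the CUDA device (matches the output quoted above).
print("JAX available devices:", jax.devices())

# The device object later passed to train_best_detector as torch_device.
torch_device = torch.device("cuda")
print("torch_device =", torch_device)
print("torch_device.type =", torch_device.type)  # -> "cuda"
```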
Root Cause:
The check within the `train_best_detector` method in `src/synthid_text/detector_bayesian.py` (around line 1032 in version 0.2.1) appears to have inverted logic:
```python
# From src/synthid_text/detector_bayesian.py
if torch_device.type in ("cuda", "tpu"):  # Checks whether the device IS cuda or tpu...
  raise ValueError(  # ...and raises the error if it is!
      "We have found the training unstable on CPUs; we are working on"
      " a fix. Use GPU or TPU for training."
  )
```

This condition incorrectly causes the error to be raised when a CUDA or TPU device is detected and passed, instead of when a CPU is detected.
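To make the inversion concrete, here is a standalone check (only `torch` is needed; constructing a CUDA `torch.device` object does not require a GPU to be present):

```python
import torch

dev = torch.device("cuda")

# Guard as shipped in 0.2.1: True for CUDA/TPU, so it rejects exactly the
# devices the error message tells the user to use.
print(dev.type in ("cuda", "tpu"))      # True  -> ValueError is raised

# Intended guard: only fire when the device is NOT an accelerator.
print(dev.type not in ("cuda", "tpu"))  # False -> training proceeds
```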
Proposed Fix:
Changing the condition to check `not in` resolves the issue:

```diff
- if torch_device.type in ("cuda", "tpu"):
+ if torch_device.type not in ("cuda", "tpu"):
    raise ValueError(
        "We have found the training unstable on CPUs; we are working on"
        " a fix. Use GPU or TPU for training."
    )
```

Applying this local modification resolved the incorrect ValueError when a CUDA device is passed.