Description
Hello,
I encountered an error while running phlash on a large Snow Leopard VCF dataset. The GPU fitting fails with an AssertionError related to the log-likelihood values. Below are the details of my environment, code, and error message.
- Environment:
Server: login01
OS: Linux
Python: 3.12 (mambaforge base environment)
phlash version: (installed via pip, latest stable)
CUDA: 12.2
GPU: 1 x NVIDIA A100 (compute capability 8.0)
JAX: installed with GPU support
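To replace "latest stable" with exact version numbers, a small script like the following could be run inside the same mambaforge environment. It is a sketch that uses only the standard library. The package list (`phlash`, `jax`, `jaxlib`, `numpy`) is an assumption about what is most useful to report, and GPU/driver details still have to come from `nvidia-smi`.

```python
import importlib.metadata
import platform
import sys


def collect_versions(packages=("phlash", "jax", "jaxlib", "numpy")):
    """Return a dict of interpreter info plus the installed version of each package.

    Packages that are not installed map to "not installed" instead of raising.
    """
    info = {"python": platform.python_version(), "platform": sys.platform}
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


# Usage: print one "name: version" line per entry for pasting into the report.
#     for key, value in collect_versions().items():
#         print(f"{key}: {value}")
```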
- Slurm script (run_phlash.slurm):
#!/bin/bash
#SBATCH --job-name=phlash_fit
#SBATCH --partition=ksagnormal01
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:1
#SBATCH --mem=24GB
#SBATCH --time=12:00:00
#SBATCH --output=phlash_%j.log
module purge
module load nvidia/cuda/12.2 compiler/gcc/9.3.0 mpi/openmpi/openmpi-4.1.5-gcc9.3.0
source ~/mambaforge/bin/activate base
python runphlash.py
echo "Job finished at $(date)"
- Python script (runphlash.py):
# -*- coding: utf-8 -*-
import os
import phlash
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pysam

def main():
    vcf_base = "/public/home/yzzzz/phlash/Snow_leopard/"
    template = "rename.vcf.gz"
    result_file = os.path.join(vcf_base, "phlash_results.pkl")
    vcf_path = os.path.join(vcf_base, template)

    # Read sample names and contig lengths from the VCF header
    vcf = pysam.VariantFile(vcf_path)
    samples_all = list(vcf.header.samples)
    chrom_lengths = {contig: vcf.header.contigs[contig].length for contig in vcf.header.contigs}

    # One phlash contig per chromosome/scaffold, spanning its full length
    chroms_all = []
    for chrom_name, chrom_len in chrom_lengths.items():
        region_str = f"{chrom_name}:1-{chrom_len}"
        chroms_all.append(phlash.contig(vcf_path, samples=samples_all, region=region_str))

    mutation_rate = 4.94E-09
    # Fitting the model
    results = phlash.fit(chroms_all, mutation_rate=mutation_rate)
    print("Done!")
    with open(result_file, "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    import multiprocessing
    multiprocessing.freeze_support()
    main()
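One thing worth ruling out in the script above is the region list itself. In pysam, `vcf.header.contigs[contig].length` is `None` when the header's `##contig` line has no `length=` field, which would produce a region string such as `chr1:1-None`, and a fragmented assembly can contribute many very short scaffolds. The hypothetical helper below builds the region strings while skipping contigs whose length is unknown or below a minimum. The 1 Mb threshold is an arbitrary assumption for illustration, not a phlash requirement, and it is not confirmed that short contigs have anything to do with this assertion.

```python
def build_regions(chrom_lengths, min_length=1_000_000):
    """Turn {contig: length} into 1-based "name:1-length" region strings.

    Hypothetical helper: contigs whose length is unknown (None) or shorter
    than min_length are left out of the regions and returned separately,
    so the choice can be reviewed before fitting.
    """
    regions, skipped = [], []
    for name, length in chrom_lengths.items():
        if length is None or length < min_length:
            skipped.append(name)
            continue
        regions.append(f"{name}:1-{length}")
    return regions, skipped


# Possible use inside main(), replacing the loop over chrom_lengths:
#     regions, skipped = build_regions(chrom_lengths)
#     print(f"Skipping {len(skipped)} short or unsized contigs")
#     chroms_all = [phlash.contig(vcf_path, samples=samples_all, region=r)
#                   for r in regions]
```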
- Log excerpt (phlash_100106084.log):
.....
Fitting model: 2%|▎ | 25/1000 [00:48<05:58, 2.72it/s]
Fitting model: 3%|▎ | 26/1000 [00:48<06:07, 2.65it/s]
ERROR:2025-10-17 15:36:53,074:jax._src.callback:102: jax.pure_callback failed
Traceback (most recent call last):
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 100, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in __call__
    self._compute_on_gpu(*args)
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
    results[i] = gpu_kernel(pp, split_index, grad, barrier)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in __call__
    assert np.all(ll < 0)
           ^^^^^^^^^^^^^^
AssertionError
E1017 15:36:54.402556 18062 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in <module>
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 339, in cache_miss
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/profiler.py", line 334, in wrapper
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1288, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 224, in _callback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 103, in pure_callback_impl
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in __call__
AssertionError:
Fitting model: 3%|▎ | 26/1000 [00:50<31:25, 1.94s/it]
Traceback (most recent call last):
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in <module>
    main()
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
    results = phlash.fit(chroms_all, mutation_rate=mutation_rate)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
    state1 = step(state, **kw)
             ^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in <module>
  File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 339, in cache_miss
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/profiler.py", line 334, in wrapper
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1288, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 224, in _callback
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 103, in pure_callback_impl
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in __call__
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
  File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in __call__
AssertionError:
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Job finished at Fri Oct 17 15:37:04 CST 2025
- Observations:
The error occurs at iteration 26 of 1000 (about 3% of the MCMC iterations), roughly 50 seconds into fitting.
Before the error, the log shows that the chunk was downsampled from 1.62GB to 0.02GB, and the minibatch size is 5.
There is also a warning: The chunk size is 35, which is less than 10 times the overlap (500).
The error does not provide information about which chunk or which ll value triggered the assertion.
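Regarding the last point: the failing check at phlash/gpu.py line 295 is `assert np.all(ll < 0)`, which fails for any entry that is zero, positive, or NaN (every comparison with NaN evaluates to False). A locally patched copy of that line could report the offending entries instead of raising a bare AssertionError. The sketch below is a stand-alone version of such a check; `report_bad_loglik` is a hypothetical name, and it assumes `ll` is a NumPy-convertible array of log-likelihoods, which has not been verified against phlash's source.

```python
import numpy as np


def report_bad_loglik(ll):
    """List the entries of ll that would fail `assert np.all(ll < 0)`.

    Returns an empty list when every value is strictly negative. Otherwise
    returns one (index, value, reason) tuple per failing entry, where reason
    is "nan", "+inf", or "non-negative" (zero or a finite positive value).
    """
    ll = np.asarray(ll, dtype=float)
    bad = []
    # ~(ll < 0) is also True for NaN, because NaN < 0 evaluates to False.
    for idx in np.argwhere(~(ll < 0)):
        idx = tuple(int(i) for i in idx)
        value = float(ll[idx])
        if np.isnan(value):
            reason = "nan"
        elif np.isinf(value):
            reason = "+inf"  # -inf < 0 is True, so only +inf can reach here
        else:
            reason = "non-negative"
        bad.append((idx, value, reason))
    return bad


# Possible use in a locally patched gpu.py (hypothetical), replacing the assert:
#     bad = report_bad_loglik(ll)
#     if bad:
#         raise AssertionError(f"{len(bad)} invalid log-likelihoods, first few: {bad[:5]}")
```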
- Steps I tried:
Checked the VCF format: it is bgzipped and indexed.
Tried running the script on the CPU only by setting CUDA_VISIBLE_DEVICES='' and use_gpu=False.
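For the CPU-only attempt, the safe approach is to set the environment variables before jax (or phlash, which imports it) is imported, since JAX reads its platform settings during initialization; exporting them in the Slurm script (for example `export JAX_PLATFORMS=cpu`) sidesteps the ordering question entirely. A minimal sketch, assuming JAX's standard `JAX_PLATFORMS` variable; `force_cpu_backend` is a hypothetical helper, and whether phlash accepts `use_gpu=False` is taken from the report above rather than verified here.

```python
import os
import sys


def force_cpu_backend():
    """Hide GPUs from JAX for the current process (sketch).

    Intended to run before jax, or anything that imports it such as phlash,
    is imported, because JAX reads these settings when it initializes.
    """
    if "jax" in sys.modules:
        raise RuntimeError("jax is already imported; set these variables earlier")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # expose no CUDA devices
    os.environ["JAX_PLATFORMS"] = "cpu"      # restrict JAX to its CPU backend


# Usage at the very top of runphlash.py, before the other imports:
#     force_cpu_backend()
#     import phlash
```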
- Questions:
Is this AssertionError caused by GPU numerical issues, small chunk/minibatch size, or something wrong in my data?
How should I adjust the chunk size, overlap, or minibatch to avoid this error?
Is there a way to debug which ll value is causing the assertion failure?
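The warning quoted under Observations implies a simple threshold: the chunk size should be at least 10 times the overlap. With the reported overlap of 500 that means chunks of at least 5,000, far above the 35 observed. The sketch below only restates that rule as worded in the log message; the factor of 10 comes from the warning text, and the function names are hypothetical, not phlash options.

```python
def min_chunk_size(overlap, factor=10):
    """Smallest chunk size that satisfies chunk_size >= factor * overlap."""
    return factor * overlap


def chunk_warning(chunk_size, overlap, factor=10):
    """Return the warning text when chunk_size < factor * overlap, else None."""
    if chunk_size < min_chunk_size(overlap, factor):
        return (f"The chunk size is {chunk_size}, which is less than "
                f"{factor} times the overlap ({overlap}).")
    return None


# With the values from the log: min_chunk_size(500) gives 5000, and
# chunk_warning(35, 500) reproduces the reported warning.
```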
Thank you for your help!