
AssertionError during GPU fitting in phlash: assert np.all(ll < 0) fails #27

@yanyanyanyanyan1

Description

Hello,

I encountered an error while running phlash on a large Snow Leopard VCF dataset. The GPU fitting fails with an AssertionError related to the log-likelihood values. Below are the details of my environment, code, and error message.

1. Environment:

Server: login01

OS: Linux

Python: 3.12 (mambaforge base environment)

phlash version: (installed via pip, latest stable)

CUDA: 12.2

GPU: 1 x NVIDIA A100 (compute capability 8.9)

JAX: installed with GPU support

2. Slurm script (run_phlash.slurm):

#!/bin/bash

#SBATCH --job-name=phlash_fit
#SBATCH --partition=ksagnormal01
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:1
#SBATCH --mem=24GB
#SBATCH --time=12:00:00
#SBATCH --output=phlash_%j.log

module purge
module load nvidia/cuda/12.2 compiler/gcc/9.3.0 mpi/openmpi/openmpi-4.1.5-gcc9.3.0

source ~/mambaforge/bin/activate base

python runphlash.py

echo "Job finished at $(date)"

3. Python script (runphlash.py):

# -*- coding: utf-8 -*-

import os
import phlash
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pysam

def main():
    vcf_base = "/public/home/yzzzz/phlash/Snow_leopard/"
    template = "rename.vcf.gz"
    result_file = os.path.join(vcf_base, "phlash_results.pkl")
    vcf_path = os.path.join(vcf_base, template)

    # Collect sample names and contig lengths from the VCF header
    vcf = pysam.VariantFile(vcf_path)
    samples_all = list(vcf.header.samples)
    chrom_lengths = {contig: vcf.header.contigs[contig].length for contig in vcf.header.contigs}

    # Build one phlash contig per chromosome/scaffold
    chroms_all = []
    for chrom_name, chrom_len in chrom_lengths.items():
        region_str = f"{chrom_name}:1-{chrom_len}"
        chroms_all.append(phlash.contig(vcf_path, samples=samples_all, region=region_str))

    mutation_rate = 4.94e-09

    # Fit the model
    results = phlash.fit(chroms_all, mutation_rate=mutation_rate)
    print("Done!")

    with open(result_file, "wb") as f:
        pickle.dump(results, f)

if __name__ == "__main__":
    import multiprocessing
    multiprocessing.freeze_support()
    main()

4. Log excerpt (phlash_100106084.log):

.....
Fitting model: 2%|▎ | 25/1000 [00:48<05:58, 2.72it/s]
Fitting model: 3%|▎ | 26/1000 [00:48<06:07, 2.65it/s]
ERROR:2025-10-17 15:36:53,074:jax._src.callback:102: jax.pure_callback failed
Traceback (most recent call last):
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 100, in pure_callback_impl
return tree_util.tree_map(np.asarray, callback(*args))
^^^^^^^^^^^^^^^
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in call
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in call
self._compute_on_gpu(*args)
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
results[i] = gpu_kernel(pp, split_index, grad, barrier)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in call
assert np.all(ll < 0)
^^^^^^^^^^^^^^
AssertionError
E1017 15:36:54.402556 18062 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 339, in cache_miss
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/profiler.py", line 334, in wrapper
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1288, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 224, in _callback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 103, in pure_callback_impl
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in call
AssertionError:

Fitting model: 3%|▎ | 26/1000 [00:50<31:25, 1.94s/it]
Traceback (most recent call last):
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in
main()
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
results = phlash.fit(chroms_all, mutation_rate=mutation_rate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
state1 = step(state, **kw)
^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 80, in
File "/public/home/yzzzz/phlash/Snow_leopard/runphlash.py", line 35, in main
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/mcmc.py", line 278, in fit
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 339, in cache_miss
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/pjit.py", line 1681, in _pjit_call_impl_python
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/profiler.py", line 334, in wrapper
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1288, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 778, in _wrapped_callback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 224, in _callback
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 103, in pure_callback_impl
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/jax/_src/callback.py", line 77, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 411, in call
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 441, in _compute_on_gpu
File "/public/home/yzzzz/mambaforge/lib/python3.12/site-packages/phlash/gpu.py", line 295, in call
AssertionError:

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Job finished at Fri Oct 17 15:37:04 CST 2025

5. Observations:

The error happens at about 3% of the MCMC iterations (step 26 of 1000).

Before the error, the log shows that the chunk data were downsampled from 1.62 GB to 0.02 GB and that the minibatch size is 5.

There is also a warning that the chunk size (35) is less than 10 times the overlap (500); a rough workaround idea is sketched after this list.

The error message does not indicate which chunk or which ll value triggered the assertion.
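
For reference, my current guess is that the warning comes from the many short scaffolds in this assembly, so one idea would be to build contigs only for long scaffolds. Below is a rough sketch of that idea; the 1 Mb cutoff is an arbitrary assumption, I have not verified that it avoids the error, and the sketch reuses vcf_path, samples_all, and chrom_lengths from runphlash.py above.

# Sketch: build phlash contigs only for scaffolds above an arbitrary length
# cutoff, so that no chunk ends up much shorter than the overlap window.
MIN_CONTIG_LEN = 1_000_000  # assumption: 1 Mb cutoff, chosen arbitrarily

chroms_all = []
for chrom_name, chrom_len in chrom_lengths.items():
    if chrom_len is None or chrom_len < MIN_CONTIG_LEN:
        continue  # skip unplaced or very short scaffolds
    region_str = f"{chrom_name}:1-{chrom_len}"
    chroms_all.append(phlash.contig(vcf_path, samples=samples_all, region=region_str))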

6. Steps I tried:

Checked the VCF format; the file is bgzipped and indexed.

Attempted to run the script on CPU only by setting CUDA_VISIBLE_DEVICES='' and use_gpu=False (roughly as in the sketch after this list).
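
For completeness, this is roughly what the CPU-only attempt looked like. The environment variables have to be set before jax/phlash are imported; I am not certain that phlash.fit actually accepts a use_gpu keyword, so the sketch relies only on the environment variables.

# Sketch of the CPU-only run: hide the GPU and force the JAX CPU backend.
# Both variables must be set before jax/phlash are imported to take effect.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # make the GPU invisible to this process
os.environ["JAX_PLATFORMS"] = "cpu"      # ask JAX to use the CPU backend only

import phlash  # imported only after the environment is configured

# ... build chroms_all exactly as in runphlash.py above ...
# results = phlash.fit(chroms_all, mutation_rate=4.94e-09)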

7. Questions:

Is this AssertionError caused by GPU numerical issues, the small chunk/minibatch size, or a problem with my data?

How should I adjust the chunk size, overlap, or minibatch size to avoid this error?

Is there a way to debug which ll value is causing the assertion to fail? (A rough sketch of the kind of instrumentation I have in mind follows.)
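
To make the last question concrete: besides rerunning with JAX_TRACEBACK_FILTERING=off as the message suggests, the only thing I could think of is temporarily editing site-packages/phlash/gpu.py just above the failing assertion (around line 295 in my install) to print the offending values. A crude sketch of that local edit, assuming ll is a NumPy array at that point; line numbers will differ between versions:

# Added locally just above the existing `assert np.all(ll < 0)` in phlash/gpu.py
# to report which log-likelihood entries are non-negative (or NaN) when it fires.
if not np.all(ll < 0):
    bad = np.flatnonzero(~(ll < 0))  # indices of non-negative or NaN entries
    print("non-negative ll at indices", bad, "values:", ll[bad], flush=True)
    print("any NaN:", bool(np.isnan(ll).any()),
          "any +inf:", bool(np.isposinf(ll).any()), flush=True)
assert np.all(ll < 0)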

Thank you for your help!
