From ee4c51a63201d9bbb638bb5fe16d38c2c5838d3d Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Thu, 30 Jan 2025 13:46:18 +0100 Subject: [PATCH 01/11] improved gpu efficiency --- pangolin/.fuse_hidden0000252700000002 | 257 --------- pangolin/model.py | 6 +- pangolin/pangolin.py | 722 +++++++++++++++++++------- 3 files changed, 549 insertions(+), 436 deletions(-) delete mode 100644 pangolin/.fuse_hidden0000252700000002 diff --git a/pangolin/.fuse_hidden0000252700000002 b/pangolin/.fuse_hidden0000252700000002 deleted file mode 100644 index 6c2d773..0000000 --- a/pangolin/.fuse_hidden0000252700000002 +++ /dev/null @@ -1,257 +0,0 @@ -import argparse -from pkg_resources import resource_filename -from pangolin.model import * -import vcf -import gffutils -import pandas as pd -import pyfastx -# import time -# startTime = time.time() - -IN_MAP = np.asarray([[0, 0, 0, 0], - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - - -def one_hot_encode(seq, strand): - seq = seq.upper().replace('A', '1').replace('C', '2') - seq = seq.replace('G', '3').replace('T', '4').replace('N', '0') - if strand == '+': - seq = np.asarray(list(map(int, list(seq)))) - elif strand == '-': - seq = np.asarray(list(map(int, list(seq[::-1])))) - seq = (5 - seq) % 5 # Reverse complement - return IN_MAP[seq.astype('int8')] - - -def compute_score(ref_seq, alt_seq, strand, d, models): - ref_seq = one_hot_encode(ref_seq, strand).T - ref_seq = torch.from_numpy(np.expand_dims(ref_seq, axis=0)).float() - alt_seq = one_hot_encode(alt_seq, strand).T - alt_seq = torch.from_numpy(np.expand_dims(alt_seq, axis=0)).float() - - if torch.cuda.is_available(): - ref_seq = ref_seq.to(torch.device("cuda")) - alt_seq = alt_seq.to(torch.device("cuda")) - - pangolin = [] - for j in range(4): - score = [] - for model in models[3*j:3*j+3]: - with torch.no_grad(): - ref = model(ref_seq)[0][[1,4,7,10][j],:].cpu().numpy() - alt = model(alt_seq)[0][[1,4,7,10][j],:].cpu().numpy() - if strand == '-': - ref = ref[::-1] - alt = alt[::-1] - l = 2*d+1 - ndiff = np.abs(len(ref)-len(alt)) - if len(ref)>len(alt): - alt = np.concatenate([alt[0:l//2+1],np.zeros(ndiff),alt[l//2+1:]]) - elif len(ref) pos or gene[4] < pos: - continue - gene_id = gene["gene_id"][0] - exons = [] - for exon in gtf.children(gene, featuretype="exon"): - exons.extend([exon[3], exon[4]]) - if gene[6] == '+': - genes_pos[gene_id] = exons - elif gene[6] == '-': - genes_neg[gene_id] = exons - - return (genes_pos, genes_neg) - - -def process_variant(lnum, chr, pos, ref, alt, gtf, models, args): - d = args.distance - cutoff = args.score_cutoff - - if len(set("ACGT").intersection(set(ref))) == 0 or len(set("ACGT").intersection(set(alt))) == 0 \ - or (len(ref) != 1 and len(alt) != 1 and len(ref) != len(alt)): - print("[Line %s]" % lnum, "WARNING, skipping variant: Variant format not supported.") - return -1 - elif len(ref) > 2*d: - print("[Line %s]" % lnum, "WARNING, skipping variant: Deletion too large") - return -1 - - fasta = pyfastx.Fasta(args.reference_file) - # try to make vcf chromosomes compatible with reference chromosomes - if chr not in fasta.keys() and "chr"+chr in fasta.keys(): - chr = "chr"+chr - elif chr not in fasta.keys() and chr[3:] in fasta.keys(): - chr = chr[3:] - - try: - seq = fasta[chr][pos-5001-d:pos+len(ref)+4999+d].seq - except Exception as e: - print(e) - print("[Line %s]" % lnum, "WARNING, skipping variant: Could not get sequence, possibly because the variant is too close to chromosome ends. 
" - "See error message above.") - return -1 - - if seq[5000+d:5000+d+len(ref)] != ref: - print("[Line %s]" % lnum, "WARNING, skipping variant: Mismatch between FASTA (ref base: %s) and variant file (ref base: %s)." - % (seq[5000+d:5000+d+len(ref)], ref)) - return -1 - - ref_seq = seq - alt_seq = seq[:5000+d] + alt + seq[5000+d+len(ref):] - - # get genes that intersect variant - genes_pos, genes_neg = get_genes(chr, pos, gtf) - if len(genes_pos)+len(genes_neg)==0: - print("[Line %s]" % lnum, "WARNING, skipping variant: Variant not contained in a gene body. Do GTF/FASTA chromosome names match?") - return -1 - - # get splice scores - loss_pos, gain_pos = None, None - if len(genes_pos) > 0: - loss_pos, gain_pos = compute_score(ref_seq, alt_seq, '+', d, models) - loss_neg, gain_neg = None, None - if len(genes_neg) > 0: - loss_neg, gain_neg = compute_score(ref_seq, alt_seq, '-', d, models) - - scores = "" - for (genes, loss, gain) in \ - ((genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg)): - for gene, positions in genes.items(): - warnings = "Warnings:" - - if args.mask == "True" and len(positions) != 0: - positions = np.array(positions) - positions = positions - (pos - d) - - positions_filt = positions[(positions>=0) & (positions=cutoff)[0] - for p, s in zip(np.concatenate([g-d,l-d]), np.concatenate([gain[g],loss[l]])): - scores += "%s:%s|" % (p, round(s,2)) - - else: - scores = scores+gene+'|' - l, g = np.argmin(loss), np.argmax(gain), - scores += "%s:%s|%s:%s|" % (g-d, round(gain[g],2), l-d, round(loss[l],2)) - - scores += warnings - - return scores.strip('|') - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("variant_file", help="VCF or CSV file with a header (see COLUMN_IDS option).") - parser.add_argument("reference_file", help="FASTA file containing a reference genome sequence.") - parser.add_argument("annotation_file", help="gffutils database file. Can be generated using create_db.py.") - parser.add_argument("output_file", help="Prefix for output file. Will be a VCF/CSV if variant_file is VCF/CSV.") - parser.add_argument("-c", "--column_ids", default="CHROM,POS,REF,ALT", help="(If variant_file is a CSV) Column IDs for: chromosome, variant position, reference bases, and alternative bases. " - "Separate IDs by commas. (Default: CHROM,POS,REF,ALT)") - parser.add_argument("-m", "--mask", default="True", choices=["False","True"], help="If True, splice gains (increases in score) at annotated splice sites and splice losses (decreases in score) at unannotated splice sites will be set to 0. (Default: True)") - parser.add_argument("-s", "--score_cutoff", type=float, help="Output all sites with absolute predicted change in score >= cutoff, instead of only the maximum loss/gain sites.") - parser.add_argument("-d", "--distance", type=int, default=50, help="Number of bases on either side of the variant for which splice scores should be calculated. (Default: 50)") - #parser.add_argument("--score_exons", default="False", choices=["False","True"], help="Output changes in score for both splice sites of annotated exons, as long as one splice site is within the considered range (specified by -d). Output will be: gene|site1_pos:score|site2_pos:score|...") - args = parser.parse_args() - - variants = args.variant_file - gtf = args.annotation_file - try: - gtf = gffutils.FeatureDB(gtf) - except: - print("ERROR, annotation_file could not be opened. 
Is it a gffutils database file?") - exit() - - if torch.cuda.is_available(): - print("Using GPU") - else: - print("Using CPU") - - models = [] - for i in [0,2,4,6]: - for j in range(1,4): - model = Pangolin(L, W, AR) - if torch.cuda.is_available(): - model.cuda() - weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i))) - else: - weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i)), map_location=torch.device('cpu')) - model.load_state_dict(weights) - model.eval() - models.append(model) - - if variants.endswith(".vcf"): - lnum = 0 - # count the number of header lines - for line in open(variants, 'r'): - lnum += 1 - if line[0] != '#': - break - - variants = vcf.Reader(filename=variants) - variants.infos["Pangolin"] = vcf.parser._Info( - "Pangolin",'.',"String","Pangolin splice scores. " - "Format: gene|pos:score_change|pos:score_change|...",'.','.') - fout = vcf.Writer(open(args.output_file+".vcf", 'w'), variants) - - for i, variant in enumerate(variants): - scores = process_variant(lnum+i, str(variant.CHROM), int(variant.POS), variant.REF, str(variant.ALT[0]), gtf, models, args) - if scores != -1: - variant.INFO["Pangolin"] = scores - fout.write_record(variant) - fout.flush() - - fout.close() - - elif variants.endswith(".csv"): - col_ids = args.column_ids.split(',') - variants = pd.read_csv(variants, header=0) - fout = open(args.output_file+".csv", 'w') - fout.write(','.join(variants.columns)+',Pangolin\n') - fout.flush() - - for lnum, variant in variants.iterrows(): - chr, pos, ref, alt = variant[col_ids] - ref, alt = ref.upper(), alt.upper() - scores = process_variant(lnum+1, str(chr), int(pos), ref, alt, gtf, models, args) - if scores == -1: - fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+'\n') - else: - fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+scores+'\n') - fout.flush() - - fout.close() - - else: - print("ERROR, variant_file needs to be a CSV or VCF.") - - # executionTime = (time.time() - startTime) - # print('Execution time in seconds: ' + str(executionTime)) - -if __name__ == '__main__': - main() diff --git a/pangolin/model.py b/pangolin/model.py index 11dfb43..e610d40 100755 --- a/pangolin/model.py +++ b/pangolin/model.py @@ -57,10 +57,13 @@ def __init__(self, L, W, AR): self.conv_last6 = nn.Conv1d(L, 1, 1) self.conv_last7 = nn.Conv1d(L, 2, 1) self.conv_last8 = nn.Conv1d(L, 1, 1) + self.W = torch.tensor(W, dtype=torch.float32) + self.AR = torch.tensor(AR, dtype=torch.float32) def forward(self, x): conv = self.conv1(x) skip = self.skip(conv) + j = 0 for i in range(len(W)): conv = self.resblocks[i](conv) @@ -68,6 +71,7 @@ def forward(self, x): dense = self.convs[j](conv) j += 1 skip = skip + dense + CL = 2 * np.sum(AR * (W - 1)) skip = F.pad(skip, (-CL // 2, -CL // 2)) out1 = F.softmax(self.conv_last1(skip), dim=1) @@ -79,5 +83,3 @@ def forward(self, x): out7 = F.softmax(self.conv_last7(skip), dim=1) out8 = torch.sigmoid(self.conv_last8(skip)) return torch.cat([out1, out2, out3, out4, out5, out6, out7, out8], 1) - - diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 194736f..1bf7e7d 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -1,12 +1,26 @@ +#!/usr/bin/env python3 + import argparse -from pkg_resources import resource_filename -from pangolin.model import * -import vcf +import tempfile import gffutils import pandas as pd import pyfastx -# import time -# startTime = time.time() +import pickle +import time +import pysam +import os 
+import shelve +import sys +import logging +import shutil + +from pkg_resources import resource_filename + +#from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Queue, Process + + +from pangolin.model import * IN_MAP = np.asarray([[0, 0, 0, 0], [1, 0, 0, 0], @@ -25,40 +39,70 @@ def one_hot_encode(seq, strand): seq = (5 - seq) % 5 # Reverse complement return IN_MAP[seq.astype('int8')] + -def compute_score(ref_seq, alt_seq, strand, d, models): - ref_seq = one_hot_encode(ref_seq, strand).T - ref_seq = torch.from_numpy(np.expand_dims(ref_seq, axis=0)).float() - alt_seq = one_hot_encode(alt_seq, strand).T - alt_seq = torch.from_numpy(np.expand_dims(alt_seq, axis=0)).float() - - if torch.cuda.is_available(): - ref_seq = ref_seq.to(torch.device("cuda")) - alt_seq = alt_seq.to(torch.device("cuda")) - - pangolin = [] - for j in range(4): - score = [] - for model in models[3*j:3*j+3]: - with torch.no_grad(): - ref = model(ref_seq)[0][[1,4,7,10][j],:].cpu().numpy() - alt = model(alt_seq)[0][[1,4,7,10][j],:].cpu().numpy() - if strand == '-': - ref = ref[::-1] - alt = alt[::-1] - l = 2*d+1 - ndiff = np.abs(len(ref)-len(alt)) - if len(ref)>len(alt): - alt = np.concatenate([alt[0:l//2+1],np.zeros(ndiff),alt[l//2+1:]]) - elif len(ref) alt.shape[-1]: + pad = torch.zeros(ref.shape[0], ndiff, device=device) + alt = torch.cat([alt[:, :l//2+1], pad, alt[:, l//2+1:]], dim=1) + elif ref.shape[-1] < alt.shape[-1]: + max_vals = alt[:, l//2:l//2+ndiff+1].max(dim=1, keepdim=True)[0] + alt = torch.cat([alt[:, :l//2], max_vals, alt[:, l//2+ndiff+1:]], dim=1) + + score.append(alt - ref) + + pangolin_scores.append(torch.stack(score).mean(dim=0)) + + # Stack scores along a new dimension + pangolin = torch.stack(pangolin_scores) + # Compute min and max scores across models (dim 0) + batch_loss, _ = pangolin.min(dim=0) + batch_gain, _ = pangolin.max(dim=0) + + # Move results to CPU and convert to lists of lists + batch_loss_cpu = batch_loss.cpu() #.numpy() # Shape: [batch_size, 101] + batch_gain_cpu = batch_gain.cpu() #.numpy() # Shape: [batch_size, 101] + + # keep cpu work as limited as possible here. 
reformat in the writer process + all_losses.append(batch_loss_cpu) + all_gains.append(batch_gain_cpu) - pangolin = np.array(pangolin) - loss = pangolin[np.argmin(pangolin, axis=0), np.arange(pangolin.shape[1])] - gain = pangolin[np.argmax(pangolin, axis=0), np.arange(pangolin.shape[1])] - return loss, gain + return all_losses, all_gains def get_genes(chr, pos, gtf): @@ -80,19 +124,20 @@ def get_genes(chr, pos, gtf): return (genes_pos, genes_neg) -def process_variant(lnum, chr, pos, ref, alt, gtf, models, args): - d = args.distance - cutoff = args.score_cutoff +def prepare_variant(lnum, chr, pos, ref, alt, gtf, args, fasta, batches, skipped_variants): + d = args.distance if len(set("ACGT").intersection(set(ref))) == 0 or len(set("ACGT").intersection(set(alt))) == 0 \ or (len(ref) != 1 and len(alt) != 1 and len(ref) != len(alt)): - print("[Line %s]" % lnum, "WARNING, skipping variant: Variant format not supported.") - return -1 + log.warning("[Line %s]" % lnum, " skipping variant: Variant format not supported.") + skipped_variants.add(f"{lnum}|{alt}") + return skipped_variants, batches + elif len(ref) > 2*d: - print("[Line %s]" % lnum, "WARNING, skipping variant: Deletion too large") - return -1 + log.warning(f"[Line {lnum}] skipping variant: Deletion too large") + skipped_variants.add(f"{lnum}|{alt}") + return skipped_variants, batches - fasta = pyfastx.Fasta(args.reference_file) # try to make vcf chromosomes compatible with reference chromosomes if chr not in fasta.keys() and "chr"+chr in fasta.keys(): chr = "chr"+chr @@ -100,17 +145,15 @@ def process_variant(lnum, chr, pos, ref, alt, gtf, models, args): chr = chr[3:] try: - seq = fasta[chr][pos-5001-d:pos+len(ref)+4999+d].seq - except Exception as e: - print(e) - print("[Line %s]" % lnum, "WARNING, skipping variant: Could not get sequence, possibly because the variant is too close to chromosome ends. " - "See error message above.") - return -1 + seq = fasta[chr][pos-5001-d:pos+len(ref)+4999+d].seq.upper() + except Exception: + skipped_variants.add(f"{lnum}|{alt}") + return skipped_variants, batches if seq[5000+d:5000+d+len(ref)] != ref: - print("[Line %s]" % lnum, "WARNING, skipping variant: Mismatch between FASTA (ref base: %s) and variant file (ref base: %s)." - % (seq[5000+d:5000+d+len(ref)], ref)) - return -1 + log.warning(f"[Line {lnum}] skipping variant: Mismatch between FASTA (ref base: {seq[5000+d:5000+d+len(ref)]}) and variant file (ref base: {ref}).") + skipped_variants.add(f"{lnum}|{alt}") + return skipped_variants, batches ref_seq = seq alt_seq = seq[:5000+d] + alt + seq[5000+d+len(ref):] @@ -118,79 +161,399 @@ def process_variant(lnum, chr, pos, ref, alt, gtf, models, args): # get genes that intersect variant genes_pos, genes_neg = get_genes(chr, pos, gtf) if len(genes_pos)+len(genes_neg)==0: - print("[Line %s]" % lnum, "WARNING, skipping variant: Variant not contained in a gene body. Do GTF/FASTA chromosome names match?") - return -1 - - # get splice scores - loss_pos, gain_pos = None, None + # no genes is not critical : keep on debug + log.debug(f"[Line {lnum}] skipping variant {chr}:{pos} {ref}/{alt}: Variant not contained in a gene body. 
Do GTF/FASTA chromosome names match?") + skipped_variants.add(f"{lnum}|{alt}") + return skipped_variants, batches + + # encode if len(genes_pos) > 0: - loss_pos, gain_pos = compute_score(ref_seq, alt_seq, '+', d, models) - loss_neg, gain_neg = None, None + ref_seq_pos = one_hot_encode(ref_seq, '+').T + ref_seq_pos = torch.from_numpy(np.expand_dims(ref_seq_pos, axis=0)).float() + alt_seq_pos = one_hot_encode(alt_seq, '+').T + alt_seq_pos = torch.from_numpy(np.expand_dims(alt_seq_pos, axis=0)).float() + + else: + ref_seq_pos, alt_seq_pos = None, None if len(genes_neg) > 0: - loss_neg, gain_neg = compute_score(ref_seq, alt_seq, '-', d, models) - - scores_list = [] - for (genes, loss, gain) in ( - (genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg) - ): - # Emit a bundle of scores/warnings per gene; join them all later - for gene, positions in genes.items(): - per_gene_scores = [] - warnings = "Warnings:" - positions = np.array(positions) - positions = positions - (pos - d) - - if args.mask == "True" and len(positions) != 0: - positions_filt = positions[(positions>=0) & (positions Loading batch {batch_nr}") + genes, variants_skipped, variant_scores = read_batches(batch_nr, tmpdir) + for variant_key in variants_skipped: + # key is variant without +/- + sh[f"{variant_key}|+"] = None + sh[f"{variant_key}|-"] = None + for idx in range(0,len(genes[0])): + # key + if genes[0][idx] not in sh: + sh[genes[0][idx]] = {'genes': genes[1][idx]} + else: + d = sh[genes[0][idx]] + d['genes'] = genes[1][idx] + sh[genes[0][idx]] = d + for variant_key in variant_scores: + # should match genes, but make sure + if variant_key not in sh: + sh[variant_key] = {'loss': variant_scores[variant_key][0], 'gain': variant_scores[variant_key][1]} + else: + d = sh[variant_key] + d.update({'loss': variant_scores[variant_key][0], 'gain': variant_scores[variant_key][1]}) + sh[variant_key] = d + # write to disk + sh.sync() + + +def vcf_writer(queue, variants, args, tmpdir): # pos, ref_seq, alt_seq, genes_pos, genes_neg, models, args): + d = args.distance + cutoff = args.score_cutoff + # variants are out of order (based on tensor size) + # 1. create a shelve + try: + fill_shelve(tmpdir, queue) + except Exception as e: + log.error(f"Shelve creation failed: {repr(e)}") + sys.exit(1) + + # 2. once all are ready => write VCF + with pysam.VariantFile(variants) as variant_file, pysam.VariantFile( + args.output_file+".vcf", "w", header=variant_file.header + ) as out_variant_file, shelve.open(f"{tmpdir}/variants.shelve") as sh: + out_variant_file.header.add_meta( + key="INFO", + items=[ + ("ID", "Pangolin"), + ("Number", "."), + ("Type", "String"), + ( + "Description", + "Pangolin splice scores. Format: gene|pos:score_change|pos:score_change|warnings,..." 
+ ), + ] + ) + + lnum = 0 + # count the number of header lines + for line in open(variants, 'r'): + lnum += 1 + if line[0] != '#': + break + + # start processing: + for idx, variant_record in enumerate(variant_file): + variant_record.translate(out_variant_file.header) + alt = str(variant_record.alts[0]) + pos = int(variant_record.pos) + variant_key = f"{lnum+idx}|{alt}" - for i in range(len(positions)//2): - p1, p2 = positions[2*i], positions[2*i+1] - if p1<0 or p1>=len(loss): - s1 = "NA" - else: - s1 = [loss[p1],gain[p1]] - s1 = round(s1[np.argmax(np.abs(s1))],2) - if p2<0 or p2>=len(loss): - s2 = "NA" + # skipped variant + if f"{variant_key}|+" not in sh and f"{variant_key}|-" not in sh: + out_variant_file.write(variant_record) + continue + # get the scores + if f"{variant_key}|+" in sh and sh[f"{variant_key}|+"] is not None: + # get the scores + loss_pos, gain_pos = sh[f"{variant_key}|+"]['loss'], sh[f"{variant_key}|+"]['gain'] + genes_pos = sh[f"{variant_key}|+"]['genes'] + else: + loss_pos, gain_pos, genes_pos = None, None, None + if f"{variant_key}|-" in sh and sh[f"{variant_key}|-"] is not None: + loss_neg, gain_neg = sh[f"{variant_key}|-"]['loss'], sh[f"{variant_key}|-"]['gain'] + genes_neg = sh[f"{variant_key}|-"]['genes'] + else: + loss_neg, gain_neg, genes_neg = None, None, None + + # reformat for vcf + scores_list = [] + for (genes, loss, gain) in ( + (genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg) + ): + if loss is None or gain is None or len(genes) == 0: + continue + # Emit a bundle of scores/warnings per gene; join them all later + for gene, positions in genes.items(): + per_gene_scores = [] + warnings = "Warnings:" + positions = np.array(positions) + positions = positions - (pos - d) + # apply masking + if args.mask == "True" and len(positions) != 0: + positions_filt = positions[(positions>=0) & (positions=len(loss): + s1 = "NA" + else: + s1 = [loss[p1],gain[p1]] + s1 = round(s1[np.argmax(np.abs(s1))],2) + if p2<0 or p2>=len(loss): + s2 = "NA" + else: + s2 = [loss[p2],gain[p2]] + s2 = round(s2[np.argmax(np.abs(s2))],2) + if s1 == "NA" and s2 == "NA": + continue + scores1.append(f"{p1-d}:{s1}") + scores2.append(f"{p2-d}:{s2}") + per_gene_scores += scores1 + scores2 + + elif cutoff != None: + per_gene_scores.append(gene) + l, g = np.where(loss<=-cutoff)[0], np.where(gain>=cutoff)[0] + for p, s in zip(np.concatenate([g-d,l-d]), np.concatenate([gain[g],loss[l]])): + per_gene_scores.append(f"{p}:{round(s,2)}") + else: - s2 = [loss[p2],gain[p2]] - s2 = round(s2[np.argmax(np.abs(s2))],2) - if s1 == "NA" and s2 == "NA": - continue - scores1.append(f"{p1-d}:{s1}") - scores2.append(f"{p2-d}:{s2}") - per_gene_scores += scores1 + scores2 - - elif cutoff != None: - per_gene_scores.append(gene) - l, g = np.where(loss<=-cutoff)[0], np.where(gain>=cutoff)[0] - for p, s in zip(np.concatenate([g-d,l-d]), np.concatenate([gain[g],loss[l]])): - per_gene_scores.append(f"{p}:{round(s,2)}") + per_gene_scores.append(gene) + l, g = np.argmin(loss), np.argmax(gain), + gain_str = f"{g-d}:{round(gain[g],2)}" + loss_str = f"{l-d}:{round(loss[l],2)}" + per_gene_scores += [gain_str, loss_str] + + per_gene_scores.append(warnings) + scores_list.append('|'.join(per_gene_scores)) + + # write to vcf + variant_record.info["Pangolin"] = ",".join(scores_list) + out_variant_file.write(variant_record) + + # remove the shelve + #os.remove(f"{tmpdir}/variants.shelve") + +def pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args, all=False): + for length in batches: + # can be equal or +1 
to variant_batchsize + if all == True or len(batches[length]['strands']) >= args.variant_batchsize: + # genes info + with open(f"{tmpdir}/batch_{batch_nr}.genes.pickle", 'wb') as f: + pickle.dump((batches[length]['variant_keys'], batches[length]['genes']), f) + # skipped variants so far + with open(f"{tmpdir}/batch_{batch_nr}.skipped_in.pickle", 'wb') as f: + pickle.dump(skipped_variants, f) + # data to run scoring + with open(f"{tmpdir}/batch_{batch_nr}.scores_in.pickle", 'wb') as f: + pickle.dump((batches[length]['ref_seqs'], batches[length]['alt_seqs'], batches[length]['strands'], batches[length]['variant_keys']), f) + # reset for next batch + batches[length] = {'ref_seqs' : [], 'alt_seqs' : [], 'strands' : [], 'variant_keys' : [], 'genes' : []} + skipped_variants = set() + # add to queue + queue.put(batch_nr) + # increase batch nr + batch_nr += 1 + return batch_nr, batches, skipped_variants + +def vcf_reader(variants, queue, gtf, args, tmpdir ): + # open fasta file + fasta = pyfastx.Fasta(args.reference_file) + # batch nr for tracking order + batch_nr = 0 + + lnum = 0 + # count the number of header lines + for line in open(variants, 'r'): + lnum += 1 + if line[0] != '#': + break + + #line numbers + skipped_variants = set() + # group variants for scoring by tensor size + batches = dict() + with pysam.VariantFile(variants) as variant_file: + for i, variant_record in enumerate(variant_file): + # validate variant + assert variant_record.ref, f"Empty REF field in variant record {variant_record}" + assert variant_record.alts, f"Empty ALT field in variant record {variant_record}" + skipped_variants, batches = prepare_variant( + lnum + i, + str(variant_record.contig), + int(variant_record.pos), + str(variant_record.ref).upper(), + str(variant_record.alts[0]).upper(), + gtf, + args, + fasta, + batches, + skipped_variants + ) + # pickle + batch_nr, batches, skipped_variants = pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args) + + # pickle remaining batches + batch_nr, batches, skipped_variants = pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args, all=True) + + # send the sentinel value to indicate the end of the queue + queue.put(None) + + +def csv_reader(variants, queue, gtf, args, tmpdir): + # open fasta file + fasta = pyfastx.Fasta(args.reference_file) + # batch nr for tracking order + batch_nr = 0 + col_ids = args.column_ids.split(',') + variants = pd.read_csv(variants, header=0) + + #line numbers + skipped_variants = set() + # group variants for scoring by tensor size + batches = dict() + + for lnum, variant in variants.iterrows(): + chr, pos, ref, alt = variant[col_ids] + ref, alt = ref.upper(), alt.upper() + # validate variant + assert ref, f"Empty REF field in variant record {lnum}" + assert alt, f"Empty ALT field in variant record {lnum}" + skipped_variants, batches = prepare_variant( + lnum , + str(chr), + int(pos), + ref.upper(), + alt.upper(), + gtf, + args, + fasta, + batches, + skipped_variants + ) + batch_nr, batches, skipped_variants = pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args) + + # pickle all remaining batches + batch_nr, batches, skipped_variants = pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args, all=True) + + # send the sentinel value to indicate the end of the queue + queue.put(None) + + +# minimal function to load the models and score the variant batches +def scoring(scoring_queue, writing_queue, args, tmpdir): + # load models + d = args.distance + models = [] + for i in 
[0,2,4,6]: + for j in range(1,4): + model = Pangolin(L, W, AR) + if torch.cuda.is_available(): + model.cuda() + weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i))) else: - per_gene_scores.append(gene) - l, g = np.argmin(loss), np.argmax(gain), - gain_str = f"{g-d}:{round(gain[g],2)}" - loss_str = f"{l-d}:{round(loss[l],2)}" - per_gene_scores += [gain_str, loss_str] + weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i)), map_location=torch.device('cpu')) + model.load_state_dict(weights) + model.eval() + models.append(model) + + # process the queue + while True: + item = scoring_queue.get() + if item is None: + log.debug("found sentinel on scoring queue. Close worker") + break + with open(f"{tmpdir}/batch_{item}.scores_in.pickle", 'rb') as f: + ref_seqs, alt_seqs, strands, variant_keys = pickle.load(f) + + # score batch variants + batch_time_start = time.time() + losses, gains = compute_scores_batch(item, ref_seqs, alt_seqs, strands, d, models, batch_size=args.tensor_batchsize) + + # pickle + batch_time_end = time.time() + print(f"Scored {len(ref_seqs)} variants in {int(batch_time_end - batch_time_start)} seconds : {int(len(ref_seqs)/((batch_time_end-batch_time_start)/3600))} variants/hour") + with open(f"{tmpdir}/batch_{item}.scores_out.pickle", 'wb') as f: + pickle.dump((variant_keys, losses, gains), f) + + writing_queue.put(item) + + # add end of queue signal + writing_queue.put(None) - per_gene_scores.append(warnings) - scores_list.append('|'.join(per_gene_scores)) - return ','.join(scores_list) def main(): parser = argparse.ArgumentParser() @@ -204,81 +567,86 @@ def main(): parser.add_argument("-s", "--score_cutoff", type=float, help="Output all sites with absolute predicted change in score >= cutoff, instead of only the maximum loss/gain sites.") parser.add_argument("-d", "--distance", type=int, default=50, help="Number of bases on either side of the variant for which splice scores should be calculated. (Default: 50)") parser.add_argument("--score_exons", default="False", choices=["False","True"], help="Output changes in score for both splice sites of annotated exons, as long as one splice site is within the considered range (specified by -d). Output will be: gene|site1_pos:score|site2_pos:score|...") + parser.add_argument("--loglevel", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level. (Default: INFO)") + parser.add_argument("--tmpdir", help="Location to create temporary directory for storing intermediate files.", default=tempfile.gettempdir()) + parser.add_argument("--variant_batchsize", type=int, default=1000, help="Number of variants to score in a single CPU batch. (Default: 1000)") + parser.add_argument("--tensor_batchsize", type=int, default=128, help="Number of variants to process in a single GPU batch. (Default: 128)") args = parser.parse_args() + # logging + logging.basicConfig(level=args.loglevel, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + global log + log = logging.getLogger(__name__) + + # start time: + log.info("Starting Pangolin") + startTime = time.time() + variants = args.variant_file gtf = args.annotation_file try: gtf = gffutils.FeatureDB(gtf) - except: - print("ERROR, annotation_file could not be opened. Is it a gffutils database file?") - exit() + except Exception as e: + log.error(f"Annotation_file could not be opened ({repr(e)}). 
Is it a gffutils database file?") + sys.exit(1) if torch.cuda.is_available(): - print("Using GPU") + log.info("Using GPU") else: - print("Using CPU") - - models = [] - for i in [0,2,4,6]: - for j in range(1,4): - model = Pangolin(L, W, AR) - if torch.cuda.is_available(): - model.cuda() - weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i))) - else: - weights = torch.load(resource_filename(__name__,"models/final.%s.%s.3.v2" % (j, i)), map_location=torch.device('cpu')) - model.load_state_dict(weights) - model.eval() - models.append(model) - + log.info("Using CPU") + + # tmp dir to pickle results + tmpdir = tempfile.mkdtemp(dir=args.tmpdir) + log.debug(f"Temporary directory for storing intermediate files: {tmpdir}") + # create queues: + scoring_queue = Queue(maxsize=25) + writing_queue = Queue(maxsize=25) + # create variant reader process: if variants.endswith(".vcf"): - lnum = 0 - # count the number of header lines - for line in open(variants, 'r'): - lnum += 1 - if line[0] != '#': - break + # process 1 : vcf => batch-pickles + reader = Process(target=vcf_reader, kwargs={"queue" : scoring_queue, "variants": variants, "gtf": gtf, "args": args, 'tmpdir': tmpdir}) + reader.start() + # process 3 : read batches into shelve => write VCF + writer = Process(target=vcf_writer, kwargs={"queue" : writing_queue, "variants": variants, "args": args, 'tmpdir': tmpdir}) + writer.start() - variants = vcf.Reader(filename=variants) - variants.infos["Pangolin"] = vcf.parser._Info( - "Pangolin",'.',"String","Pangolin splice scores. " - "Format: gene|pos:score_change|pos:score_change|warnings,...",'.','.') - fout = vcf.Writer(open(args.output_file+".vcf", 'w'), variants) + elif variants.endswith(".csv"): + reader = Process(target=csv_reader, kwargs={"queue" : scoring_queue, "variants": variants, "gtf": gtf, "args": args, 'tmpdir': tmpdir}) + reader.start() + # process 3 : read batches into shelve => write VCF + writer = Process(target=csv_writer, kwargs={"queue" : writing_queue, "variants": variants, "args": args, 'tmpdir': tmpdir}) + writer.start() + + else: + log.error("Variant_file needs to be a CSV or VCF.") - for i, variant in enumerate(variants): - scores = process_variant(lnum+i, str(variant.CHROM), int(variant.POS), variant.REF, str(variant.ALT[0]), gtf, models, args) - if scores != -1: - variant.INFO["Pangolin"] = scores - fout.write_record(variant) - fout.flush() + # score the variants (pickle => gpu => pickle) + # todo : if this can be subprocessed, we can use multiple gpus. (torch.multiprocessing ?) 
+ try: + scoring(scoring_queue=scoring_queue,writing_queue=writing_queue,args=args,tmpdir=tmpdir) + except Exception as e: + log.error(f"Scoring process failed: {repr(e)}") + sys.exit(1) - fout.close() + # join the reader + reader.join() + # close first queue + scoring_queue.close() - elif variants.endswith(".csv"): - col_ids = args.column_ids.split(',') - variants = pd.read_csv(variants, header=0) - fout = open(args.output_file+".csv", 'w') - fout.write(','.join(variants.columns)+',Pangolin\n') - fout.flush() - - for lnum, variant in variants.iterrows(): - chr, pos, ref, alt = variant[col_ids] - ref, alt = ref.upper(), alt.upper() - scores = process_variant(lnum+1, str(chr), int(pos), ref, alt, gtf, models, args) - if scores == -1: - fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+'\n') - else: - fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+scores+'\n') - fout.flush() + # join the writer + writer.join() + # close second queue + writing_queue.close() - fout.close() - else: - print("ERROR, variant_file needs to be a CSV or VCF.") + executionTime = (time.time() - startTime) + # format exec time in hours:minutes:sec + log.info(f"Execution time: {int(executionTime//3600)}h:{int((executionTime%3600)//60)}m:{int(executionTime%60)}s") - # executionTime = (time.time() - startTime) - # print('Execution time in seconds: ' + str(executionTime)) + # remove temp folder and contents + shutil.rmtree(tmpdir) + if __name__ == '__main__': main() From c957f036ba31a7f83bef21f84e429c17b27606a9 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Thu, 30 Jan 2025 16:43:13 +0100 Subject: [PATCH 02/11] added the csv_writer --- pangolin/pangolin.py | 175 +++++++++++++++++++++++++++---------------- 1 file changed, 109 insertions(+), 66 deletions(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 1bf7e7d..6db1c93 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -275,7 +275,68 @@ def fill_shelve(tmpdir, queue): sh.sync() -def vcf_writer(queue, variants, args, tmpdir): # pos, ref_seq, alt_seq, genes_pos, genes_neg, models, args): +def format_scores(loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff): + scores_list = [] + for (genes, loss, gain) in ( + (genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg) + ): + if loss is None or gain is None or len(genes) == 0: + continue + # Emit a bundle of scores/warnings per gene; join them all later + for gene, positions in genes.items(): + per_gene_scores = [] + warnings = "Warnings:" + positions = np.array(positions) + positions = positions - (pos - d) + # apply masking + if args.mask == "True" and len(positions) != 0: + positions_filt = positions[(positions>=0) & (positions=len(loss): + s1 = "NA" + else: + s1 = [loss[p1],gain[p1]] + s1 = round(s1[np.argmax(np.abs(s1))],2) + if p2<0 or p2>=len(loss): + s2 = "NA" + else: + s2 = [loss[p2],gain[p2]] + s2 = round(s2[np.argmax(np.abs(s2))],2) + if s1 == "NA" and s2 == "NA": + continue + scores1.append(f"{p1-d}:{s1}") + scores2.append(f"{p2-d}:{s2}") + per_gene_scores += scores1 + scores2 + elif cutoff != None: + per_gene_scores.append(gene) + l, g = np.where(loss<=-cutoff)[0], np.where(gain>=cutoff)[0] + for p, s in zip(np.concatenate([g-d,l-d]), np.concatenate([gain[g],loss[l]])): + per_gene_scores.append(f"{p}:{round(s,2)}") + else: + per_gene_scores.append(gene) + l, g = np.argmin(loss), np.argmax(gain), + gain_str = f"{g-d}:{round(gain[g],2)}" + loss_str = f"{l-d}:{round(loss[l],2)}" + per_gene_scores += 
[gain_str, loss_str] + per_gene_scores.append(warnings) + scores_list.append('|'.join(per_gene_scores)) + return ",".join(scores_list) + + +def vcf_writer(queue, variants, args, tmpdir): d = args.distance cutoff = args.score_cutoff # variants are out of order (based on tensor size) @@ -335,75 +396,57 @@ def vcf_writer(queue, variants, args, tmpdir): # pos, ref_seq, alt_seq, genes_po loss_neg, gain_neg, genes_neg = None, None, None # reformat for vcf - scores_list = [] - for (genes, loss, gain) in ( - (genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg) - ): - if loss is None or gain is None or len(genes) == 0: - continue - # Emit a bundle of scores/warnings per gene; join them all later - for gene, positions in genes.items(): - per_gene_scores = [] - warnings = "Warnings:" - positions = np.array(positions) - positions = positions - (pos - d) - # apply masking - if args.mask == "True" and len(positions) != 0: - positions_filt = positions[(positions>=0) & (positions=len(loss): - s1 = "NA" - else: - s1 = [loss[p1],gain[p1]] - s1 = round(s1[np.argmax(np.abs(s1))],2) - if p2<0 or p2>=len(loss): - s2 = "NA" - else: - s2 = [loss[p2],gain[p2]] - s2 = round(s2[np.argmax(np.abs(s2))],2) - if s1 == "NA" and s2 == "NA": - continue - scores1.append(f"{p1-d}:{s1}") - scores2.append(f"{p2-d}:{s2}") - per_gene_scores += scores1 + scores2 - - elif cutoff != None: - per_gene_scores.append(gene) - l, g = np.where(loss<=-cutoff)[0], np.where(gain>=cutoff)[0] - for p, s in zip(np.concatenate([g-d,l-d]), np.concatenate([gain[g],loss[l]])): - per_gene_scores.append(f"{p}:{round(s,2)}") - - else: - per_gene_scores.append(gene) - l, g = np.argmin(loss), np.argmax(gain), - gain_str = f"{g-d}:{round(gain[g],2)}" - loss_str = f"{l-d}:{round(loss[l],2)}" - per_gene_scores += [gain_str, loss_str] - - per_gene_scores.append(warnings) - scores_list.append('|'.join(per_gene_scores)) + scores = format_scores(loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff) # write to vcf - variant_record.info["Pangolin"] = ",".join(scores_list) + variant_record.info["Pangolin"] = scores out_variant_file.write(variant_record) - # remove the shelve - #os.remove(f"{tmpdir}/variants.shelve") +def csv_writer(queue, variants, args, tmpdir): + d = args.distance + cutoff = args.score_cutoff + # variants are out of order (based on tensor size) + # 1. create a shelve + try: + fill_shelve(tmpdir, queue) + except Exception as e: + log.error(f"Shelve creation failed: {repr(e)}") + sys.exit(1) + + # 2. 
once all are ready => write CSV + col_ids = args.column_ids.split(',') + with pd.read_csv(variants, header=0) as variant_file, open(args.output_file+".csv", 'w') as fout, shelve.open(f"{tmpdir}/variants.shelve") as sh: + #variants = pd.read_csv(variants, header=0) + fout = open(args.output_file+".csv", 'w') + fout.write(','.join(variant_file.columns)+',Pangolin\n') + fout.flush() + for lnum, variant in variant_file.iterrows(): + chr, pos, ref, alt = variant[col_ids] + ref, alt = ref.upper(), alt.upper() + variant_key = f"{lnum}|{alt}" + # skipped variant + if f"{variant_key}|+" not in sh and f"{variant_key}|-" not in sh: + fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+'\n') + continue + # get the scores + if f"{variant_key}|+" in sh and sh[f"{variant_key}|+"] is not None: + # get the scores + loss_pos, gain_pos = sh[f"{variant_key}|+"]['loss'], sh[f"{variant_key}|+"]['gain'] + genes_pos = sh[f"{variant_key}|+"]['genes'] + else: + loss_pos, gain_pos, genes_pos = None, None, None + if f"{variant_key}|-" in sh and sh[f"{variant_key}|-"] is not None: + loss_neg, gain_neg = sh[f"{variant_key}|-"]['loss'], sh[f"{variant_key}|-"]['gain'] + genes_neg = sh[f"{variant_key}|-"]['genes'] + else: + loss_neg, gain_neg, genes_neg = None, None, None + + scores = format_scores(loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff) + + fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+scores+'\n') + fout.flush() + + fout.close() def pickle_batches(batches, skipped_variants, batch_nr, tmpdir, queue, args, all=False): for length in batches: From cbff8ea782bd591a055bb324290c21e576b432d5 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Fri, 31 Jan 2025 09:02:12 +0100 Subject: [PATCH 03/11] updated readme and default batchsize --- README.md | 17 ++++++++++++++--- pangolin/pangolin.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 912ddb9..cd3b479 100755 --- a/README.md +++ b/README.md @@ -52,15 +52,17 @@ See below for information on usage and local installation. ``` See full options below: ``` - usage: pangolin [-h] [-c COLUMN_IDS] [-m {False,True}] [-s SCORE_CUTOFF] [-d DISTANCE] variant_file reference_file annotation_file output_file + usage: pangolin [-h] [-c COLUMN_IDS] [-m {False,True}] [-s SCORE_CUTOFF] [-d DISTANCE] [--score_exons {False,True}] [--loglevel {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--tmpdir TMPDIR] [--variant_batchsize VARIANT_BATCHSIZE] + [--tensor_batchsize TENSOR_BATCHSIZE] + variant_file reference_file annotation_file output_file positional arguments: variant_file VCF or CSV file with a header (see COLUMN_IDS option). reference_file FASTA file containing a reference genome sequence. annotation_file gffutils database file. Can be generated using create_db.py. output_file Prefix for output file. Will be a VCF/CSV if variant_file is VCF/CSV. - - optional arguments: + + options: -h, --help show this help message and exit -c COLUMN_IDS, --column_ids COLUMN_IDS (If variant_file is a CSV) Column IDs for: chromosome, variant position, reference bases, and alternative bases. Separate IDs by commas. (Default: CHROM,POS,REF,ALT) @@ -70,6 +72,15 @@ See below for information on usage and local installation. Output all sites with absolute predicted change in score >= cutoff, instead of only the maximum loss/gain sites. -d DISTANCE, --distance DISTANCE Number of bases on either side of the variant for which splice scores should be calculated. 
(Default: 50) + --score_exons {False,True} + Output changes in score for both splice sites of annotated exons, as long as one splice site is within the considered range (specified by -d). Output will be: gene|site1_pos:score|site2_pos:score|... + --loglevel {DEBUG,INFO,WARNING,ERROR,CRITICAL} + Set the logging level. (Default: INFO) + --tmpdir TMPDIR Location to create temporary directory for storing intermediate files. + --variant_batchsize VARIANT_BATCHSIZE + Number of variants to score in a single CPU batch. (Default: 1280) + --tensor_batchsize TENSOR_BATCHSIZE + Number of variants to process in a single GPU batch. (Default: 128) ``` ### Usage (custom) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 6db1c93..71bfb06 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -612,7 +612,7 @@ def main(): parser.add_argument("--score_exons", default="False", choices=["False","True"], help="Output changes in score for both splice sites of annotated exons, as long as one splice site is within the considered range (specified by -d). Output will be: gene|site1_pos:score|site2_pos:score|...") parser.add_argument("--loglevel", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level. (Default: INFO)") parser.add_argument("--tmpdir", help="Location to create temporary directory for storing intermediate files.", default=tempfile.gettempdir()) - parser.add_argument("--variant_batchsize", type=int, default=1000, help="Number of variants to score in a single CPU batch. (Default: 1000)") + parser.add_argument("--variant_batchsize", type=int, default=1280, help="Number of variants to score in a single CPU batch. (Default: 1280)") parser.add_argument("--tensor_batchsize", type=int, default=128, help="Number of variants to process in a single GPU batch. 
(Default: 128)") args = parser.parse_args() From 7a9b31336e43c15eba7e4c51e179cf789fac6a35 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Fri, 31 Jan 2025 11:58:22 +0100 Subject: [PATCH 04/11] add dockerfile based on cuda 12.0.0 --- docker/Dockerfile | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 docker/Dockerfile diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..58bcf9e --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,64 @@ +###################################### +## CONTAINER FOR GPU based pangolin ## +###################################### + +# start from the cuda docker base +from nvidia/cuda:12.0.0-runtime-ubuntu22.04 + +## needed apt packages +ARG BUILD_PACKAGES="wget git bzip2" +# needed conda packages +ARG CONDA_PACKAGES="python==3.10.8 pip==25.0 pandas==2.2.3 pyfastx==0.8.4 gffutils==0.13 pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda==12.4 pysam==0.20.0" +ARG CONDA_CHANNEL="-c nvidia -c pytorch -c conda-forge -c anaconda -c bioconda" +## ENV SETTINGS during runtime +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 +ENV PATH=/opt/conda/bin:$PATH +ENV DEBIAN_FRONTEND noninteractive + +## AUTHOR +ENV AUTHOR="Geert Vandeweyer" +ENV EMAIL="geert.vandeweyer@uza.be" + +# For micromamba: +SHELL ["/bin/bash", "-l", "-c"] +ENV MAMBA_ROOT_PREFIX=/opt/conda/ +ENV PATH=/opt/micromamba/bin:/opt/conda/bin:$PATH + + +## INSTALL +RUN apt-get -y update && \ + apt-get -y install $BUILD_PACKAGES && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + + +# conda packages +RUN mkdir /opt/conda && \ + mkdir /opt/micromamba && \ + wget -qO - https://micromamba.snakepit.net/api/micromamba/linux-64/0.23.0 | tar -xvj -C /opt/micromamba bin/micromamba && \ + # initialize bash + micromamba shell init --shell=bash --prefix=/opt/conda && \ + # remove a statement from bashrc that prevents initialization + grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/micromamba/bashrc && \ + mv /opt/micromamba/bashrc /root/.bashrc && \ + source ~/.bashrc && \ + # activate & install base conda packag + micromamba activate && \ + micromamba install -y $CONDA_CHANNEL $CONDA_PACKAGES && \ + micromamba clean --all --yes + +# Break cache for recloning git +ENV DATE_CACHE_BREAK=$(date) + +# my fork of pangolin : has gpu optimizations +RUN cd /opt/ && \ + git clone https://github.com/geertvandeweyer/pangolin.git && \ + cd pangolin && \ + pip install . 
+ +# ADD annotation data + +## EXAMPLE: +# docker run --rm --gpus all -it -v /home:/home spliceai:1.3 nvidia-smi + +#spliceai -I /home/geert/TestVCF/wessd-228741-i.haplotypecaller.final.vcf -O /home/geert/TestVCF/wessd-228741-i.haplotypecaller.final.out.gpu.vcf -R /home/geert/refGenome/hg19.fasta -A grch37 -M 1 -B 32 \ No newline at end of file From dddee3e575e6f0161553bbc16fdbc2a03f22acc3 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Sat, 1 Feb 2025 10:23:09 +0100 Subject: [PATCH 05/11] fixed bug in format_scores --- pangolin/pangolin.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 71bfb06..9ce29d2 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -275,7 +275,7 @@ def fill_shelve(tmpdir, queue): sh.sync() -def format_scores(loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff): +def format_scores(position, loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff): scores_list = [] for (genes, loss, gain) in ( (genes_pos,loss_pos,gain_pos),(genes_neg,loss_neg,gain_neg) @@ -287,7 +287,7 @@ def format_scores(loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, per_gene_scores = [] warnings = "Warnings:" positions = np.array(positions) - positions = positions - (pos - d) + positions = positions - (position - d) # apply masking if args.mask == "True" and len(positions) != 0: positions_filt = positions[(positions>=0) & (positions Date: Sat, 1 Feb 2025 10:37:56 +0100 Subject: [PATCH 06/11] correct csv opening in csv_writer --- pangolin/pangolin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 9ce29d2..34248c7 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -415,8 +415,9 @@ def csv_writer(queue, variants, args, tmpdir): # 2. once all are ready => write CSV col_ids = args.column_ids.split(',') - with pd.read_csv(variants, header=0) as variant_file, open(args.output_file+".csv", 'w') as fout, shelve.open(f"{tmpdir}/variants.shelve") as sh: - #variants = pd.read_csv(variants, header=0) + with open(args.output_file+".csv", 'w') as fout, shelve.open(f"{tmpdir}/variants.shelve") as sh: + # this reads the whole file in memory. 
+ variant_file = pd.read_csv(variants, header=0) fout = open(args.output_file+".csv", 'w') fout.write(','.join(variant_file.columns)+',Pangolin\n') fout.flush() From 7db24b94b209f1389cf8a3694bdc0e3809137adf Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Sat, 1 Feb 2025 11:00:07 +0100 Subject: [PATCH 07/11] print to log --- pangolin/pangolin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 34248c7..63c2151 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -588,7 +588,7 @@ def scoring(scoring_queue, writing_queue, args, tmpdir): # pickle batch_time_end = time.time() - print(f"Scored {len(ref_seqs)} variants in {int(batch_time_end - batch_time_start)} seconds : {int(len(ref_seqs)/((batch_time_end-batch_time_start)/3600))} variants/hour") + log.info(f"Scored {len(ref_seqs)} variants in {int(batch_time_end - batch_time_start)} seconds : {int(len(ref_seqs)/((batch_time_end-batch_time_start)/3600))} variants/hour") with open(f"{tmpdir}/batch_{item}.scores_out.pickle", 'wb') as f: pickle.dump((variant_keys, losses, gains), f) From c94e6046432cf4804a793ba3ce9586856dac6368 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Sat, 1 Feb 2025 12:48:17 +0100 Subject: [PATCH 08/11] match original output without pangolin field if no scores returned --- pangolin/pangolin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 63c2151..5251406 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -397,9 +397,9 @@ def vcf_writer(queue, variants, args, tmpdir): # reformat for vcf scores = format_scores(position, loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff) - + if scores != "": + variant_record.info["Pangolin"] = scores # write to vcf - variant_record.info["Pangolin"] = scores out_variant_file.write(variant_record) def csv_writer(queue, variants, args, tmpdir): From 2800dd6bb5fc5b07b9bfafb3656c0b2b8fd5ab15 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Sun, 2 Feb 2025 17:43:43 +0100 Subject: [PATCH 09/11] make csv writer output identical if no scores availalble, add dockerfile --- docker/Dockerfile | 9 +-------- pangolin/pangolin.py | 6 ++++-- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 58bcf9e..30ad3d3 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -47,9 +47,6 @@ RUN mkdir /opt/conda && \ micromamba install -y $CONDA_CHANNEL $CONDA_PACKAGES && \ micromamba clean --all --yes -# Break cache for recloning git -ENV DATE_CACHE_BREAK=$(date) - # my fork of pangolin : has gpu optimizations RUN cd /opt/ && \ git clone https://github.com/geertvandeweyer/pangolin.git && \ @@ -57,8 +54,4 @@ RUN cd /opt/ && \ pip install . 
# ADD annotation data - -## EXAMPLE: -# docker run --rm --gpus all -it -v /home:/home spliceai:1.3 nvidia-smi - -#spliceai -I /home/geert/TestVCF/wessd-228741-i.haplotypecaller.final.vcf -O /home/geert/TestVCF/wessd-228741-i.haplotypecaller.final.out.gpu.vcf -R /home/geert/refGenome/hg19.fasta -A grch37 -M 1 -B 32 \ No newline at end of file +CMD ["pangolin", "--help"] diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 5251406..97a2f18 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -443,8 +443,10 @@ def csv_writer(queue, variants, args, tmpdir): loss_neg, gain_neg, genes_neg = None, None, None scores = format_scores(position, loss_pos, gain_pos, genes_pos, loss_neg, gain_neg, genes_neg, d, args, cutoff) - - fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+scores+'\n') + if scores != "": + fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+scores+'\n') + else: + fout.write(','.join(variant.to_csv(header=False, index=False).split('\n'))+'\n') fout.flush() fout.close() From 81a16183f825162dd93f641d1d558ca900ac4de9 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 4 Nov 2025 13:58:46 +0100 Subject: [PATCH 10/11] correct handling of errors --- pangolin/pangolin.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/pangolin/pangolin.py b/pangolin/pangolin.py index 97a2f18..78a5054 100755 --- a/pangolin/pangolin.py +++ b/pangolin/pangolin.py @@ -13,7 +13,7 @@ import sys import logging import shutil - +import traceback from pkg_resources import resource_filename #from concurrent.futures import ThreadPoolExecutor @@ -673,17 +673,36 @@ def main(): scoring(scoring_queue=scoring_queue,writing_queue=writing_queue,args=args,tmpdir=tmpdir) except Exception as e: log.error(f"Scoring process failed: {repr(e)}") + traceback.print_exc(file=sys.stderr) + + # Clean shutdown of processes + log.warning("Terminating processes...") + for process in (reader, writer): + if process.is_alive(): + process.terminate() + process.join(timeout=5) + if process.is_alive(): + process.kill() + + # Close queues + for queue in (scoring_queue, writing_queue): + queue.close() + + # Clean up temporary directory + try: + shutil.rmtree(tmpdir) + except Exception as cleanup_error: + log.warning(f"Could not remove temporary directory {tmpdir}: {repr(cleanup_error)}") + # exit with error sys.exit(1) - # join the reader - reader.join() - # close first queue - scoring_queue.close() - - # join the writer - writer.join() - # close second queue - writing_queue.close() + # join the processes + for process in (reader, writer): + process.join() + + # close queues + for queue in (scoring_queue, writing_queue): + queue.close() executionTime = (time.time() - startTime) From 4550fb0160151718cab916cc7d7e349b4e5cb142 Mon Sep 17 00:00:00 2001 From: Geert Vandeweyer Date: Tue, 4 Nov 2025 14:47:39 +0100 Subject: [PATCH 11/11] handle padding issues for short sequences --- pangolin/model.py | 47 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/pangolin/model.py b/pangolin/model.py index e610d40..cce5194 100755 --- a/pangolin/model.py +++ b/pangolin/model.py @@ -73,13 +73,40 @@ def forward(self, x): skip = skip + dense CL = 2 * np.sum(AR * (W - 1)) - skip = F.pad(skip, (-CL // 2, -CL // 2)) - out1 = F.softmax(self.conv_last1(skip), dim=1) - out2 = torch.sigmoid(self.conv_last2(skip)) - out3 = F.softmax(self.conv_last3(skip), dim=1) - out4 = 
torch.sigmoid(self.conv_last4(skip)) - out5 = F.softmax(self.conv_last5(skip), dim=1) - out6 = torch.sigmoid(self.conv_last6(skip)) - out7 = F.softmax(self.conv_last7(skip), dim=1) - out8 = torch.sigmoid(self.conv_last8(skip)) - return torch.cat([out1, out2, out3, out4, out5, out6, out7, out8], 1) + trim_amount = CL // 2 + seq_length = skip.shape[-1] + batch_size = skip.shape[0] + + # Try vectorized processing first + try: + trimmed = F.pad(skip, (-trim_amount, -trim_amount)) + out1 = F.softmax(self.conv_last1(trimmed), dim=1) + out2 = torch.sigmoid(self.conv_last2(trimmed)) + out3 = F.softmax(self.conv_last3(trimmed), dim=1) + out4 = torch.sigmoid(self.conv_last4(trimmed)) + out5 = F.softmax(self.conv_last5(trimmed), dim=1) + out6 = torch.sigmoid(self.conv_last6(trimmed)) + out7 = F.softmax(self.conv_last7(trimmed), dim=1) + out8 = torch.sigmoid(self.conv_last8(trimmed)) + return torch.cat([out1, out2, out3, out4, out5, out6, out7, out8], 1) + except RuntimeError: + # Fallback: process each sample individually + outputs = [] + for batch_idx in range(batch_size): + sample = skip[batch_idx:batch_idx+1] + try: + trimmed = F.pad(sample, (-trim_amount, -trim_amount)) + out1 = F.softmax(self.conv_last1(trimmed), dim=1) + out2 = torch.sigmoid(self.conv_last2(trimmed)) + out3 = F.softmax(self.conv_last3(trimmed), dim=1) + out4 = torch.sigmoid(self.conv_last4(trimmed)) + out5 = F.softmax(self.conv_last5(trimmed), dim=1) + out6 = torch.sigmoid(self.conv_last6(trimmed)) + out7 = F.softmax(self.conv_last7(trimmed), dim=1) + out8 = torch.sigmoid(self.conv_last8(trimmed)) + out = torch.cat([out1, out2, out3, out4, out5, out6, out7, out8], 1) + except RuntimeError: + # Zero output for this sample if it fails + out = torch.zeros(1, 12, seq_length, dtype=skip.dtype, device=skip.device) + outputs.append(out) + return torch.cat(outputs, dim=0)
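
---

As a rough usage sketch of the CLI after this series (the flags, defaults, and positional arguments come from the patched README and argparse above; the file names below are placeholders, not taken from the patches):

```
pangolin variants.vcf reference.fa annotation.db variants_scored \
    --variant_batchsize 1280 \
    --tensor_batchsize 128 \
    --tmpdir /scratch \
    --loglevel INFO
```

Here `--variant_batchsize` sets how many variants go into each pickled CPU batch before it is queued for scoring, and `--tensor_batchsize` sets how many variants are pushed through the models in a single GPU batch (larger values generally trade GPU memory for throughput). The output prefix `variants_scored` would yield `variants_scored.vcf`, matching the existing VCF-in/VCF-out behaviour.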