|
| 1 | +"""Tool to compute haplotypes from VCF files with gnomAD population frequency annotations.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import Any |
| 5 | +from typing import Callable |
| 6 | + |
| 7 | +import hail as hl |
| 8 | + |
| 9 | +from divref.haplotype import HailPath |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | + |
def _get_haplotypes(
    ht: Any,
    windower_f: Callable[[Any], Any],
    idx: int,
    output_base: HailPath,
    pop_ints: dict[str, int],
) -> Any:
    """
    Group variants into haplotypes within genomic windows and compute empirical frequencies.

    Applies the windowing function to assign variants to windows, aggregates haplotypes per
    population and sample, filters to multi-variant haplotypes, and collapses across samples
    to compute empirical allele counts and frequencies. Writes intermediate results to a
    checkpoint table.

    Args:
        ht: Hail table with per-variant population membership and frequency data.
        windower_f: Function mapping a Hail locus to the window locus key.
        idx: Index of this windowing pass (1 or 2), used in the checkpoint filename.
        output_base: Base path for output; checkpoint written to {output_base}.{idx}.ht.
        pop_ints: Mapping from population code to integer index.

    Returns:
        Hail table of haplotypes with empirical frequency summaries.
    """
    # Assign each variant to its window; variants sharing a window locus are the
    # candidates for forming a haplotype together.
    new_locus = windower_f(ht.locus)
    ht = ht.annotate(new_locus=new_locus)

    def agg_haplos(arr: Any) -> Any:
        # Aggregator for one phase side (left/right): collect each carrying
        # (pop, sample) tagged with the variant's row_idx, then per population
        # build sample -> sorted row_idx list (that sample's haplotype in this
        # window), keep only multi-variant haplotypes, and count how many
        # samples carry each distinct haplotype.
        flat = hl.agg.explode(lambda elt: hl.agg.collect(elt.annotate(row_idx=ht.row_idx)), arr)
        pop_grouped = hl.group_by(lambda x: x.pop, flat)
        return pop_grouped.map_values(
            lambda arr_per_pop: hl.array(
                hl.array(hl.group_by(lambda inner_elt: inner_elt.sample, arr_per_pop))
                .filter(lambda sample_and_records: hl.len(sample_and_records[1]) > 1)
                .map(
                    lambda sample_and_records: hl.sorted(
                        sample_and_records[1].map(lambda e: e.row_idx)
                    )
                )
                .group_by(lambda x: x)
                .map_values(lambda arr: hl.len(arr))
            )
        )

    # One row per window: a row_idx -> variant-data lookup plus the per-phase
    # haplotype counts computed by agg_haplos.
    ht_grouped = ht.group_by("new_locus").aggregate(
        row_map=hl.dict(
            hl.agg.collect((
                ht.row_idx,
                ht.row.select("locus", "alleles", "freq", "frequencies_by_pop"),
            ))
        ),
        left_haplos=agg_haplos(ht.pops_and_ids_left),
        right_haplos=agg_haplos(ht.pops_and_ids_right),
    )

    def collapse_haplos_across_samples(pop: Any, arr1: Any, arr2: Any) -> Any:
        # Assumes all AN == 2 * N_samples.
        # Merge both phase sides for this population.
        # NOTE(review): x.get(pop) is missing when a population has no
        # qualifying haplotypes in this window — confirm flatmap over a missing
        # inner array behaves as intended here.
        flat = hl.array([arr1, arr2]).flatmap(lambda x: x.get(pop))

        def map_haplo_group(t: Any) -> Any:
            haplotype = t[0]
            # Total observations of this haplotype summed across phase sides.
            n_observed = hl.sum(t[1].map(lambda x: x[1]))
            component_variant_frequencies = haplotype.map(
                lambda x: ht_grouped.row_map[x].frequencies_by_pop[pop]
            )
            # The smallest AN among component variants bounds the number of
            # chromosomes on which the full haplotype could be observed.
            min_an = hl.min(component_variant_frequencies.map(lambda x: x.AN))
            return hl.struct(
                haplotype=haplotype,
                pop=pop,
                empirical_AC=n_observed,
                # AF[1] is the alt-allele frequency from hl.agg.call_stats.
                min_variant_frequency=hl.min(component_variant_frequencies.map(lambda x: x.AF[1])),
                empirical_AF=n_observed / min_an,
            )

        # Group identical haplotypes (same sorted row_idx list) and summarize.
        return hl.array(hl.group_by(lambda x: x[0], flat)).map(map_haplo_group)

    # Compute per-population haplotype records for every population index.
    ht_grouped = ht_grouped.annotate(
        all_haplos=hl.literal(list(pop_ints.values())).flatmap(
            lambda pop: collapse_haplos_across_samples(
                pop, ht_grouped.left_haplos, ht_grouped.right_haplos
            )
        )
    )

    def get_haplotype_summary(a: Any) -> dict[str, Any]:
        # Summarize one haplotype's per-population records, led by the
        # population with the highest empirical AF.
        a_sorted = hl.sorted(a, key=lambda x: x.empirical_AF, reverse=True)
        return dict(
            max_pop=a_sorted[0].pop,
            max_empirical_AF=a_sorted[0].empirical_AF,
            max_empirical_AC=a_sorted[0].empirical_AC,
            min_variant_frequency=a_sorted[0].min_variant_frequency,
            all_pop_freqs=a_sorted.map(lambda x: x.drop("haplotype")),
        )

    # Collapse per-population records into one record per distinct haplotype.
    ht_grouped = ht_grouped.transmute(
        all_haplos=hl.array(hl.group_by(lambda x: x.haplotype, ht_grouped.all_haplos)).map(
            lambda t: hl.struct(haplotype=t[0], **get_haplotype_summary(t[1]))
        )
    )

    # One row per (window, haplotype); the window key is no longer needed.
    hte = ht_grouped.explode("all_haplos")
    hte = hte.key_by().drop("new_locus")

    def get_variant(row_idx: Any) -> Any:
        # Resolve a row index back to its (locus, alleles) pair.
        return hte.row_map[row_idx].select("locus", "alleles")

    def get_gnomad_freq(row_idx: Any) -> Any:
        # Resolve a row index back to its gnomAD frequency record.
        return hte.row_map[row_idx].freq

    hte = hte.select(
        **hte.all_haplos,
        variants=hte.all_haplos.haplotype.map(get_variant),
        gnomad_freqs=hte.all_haplos.haplotype.map(get_gnomad_freq),
    )

    # Re-group by haplotype and keep the record with the highest
    # max_empirical_AF (splatted back to top-level fields).
    # NOTE(review): within one pass each variant belongs to exactly one window,
    # so duplicates here would be unexpected — confirm this dedup is for safety.
    hte = hte.group_by("haplotype").aggregate(
        **hl.sorted(
            hl.agg.collect(hte.row.drop("haplotype")),
            key=lambda row: -row.max_empirical_AF,
        )[0]
    )

    logger.info("Writing %s.%s.ht ...", output_base, idx)
    return hte.checkpoint(f"{output_base}.{idx}.ht", overwrite=True)
| 139 | + |
| 140 | + |
def compute_haplotypes(
    *,
    vcfs_path: HailPath,
    gnomad_va_file: HailPath,
    gnomad_sa_file: HailPath,
    window_size: int,
    freq_threshold: float,
    output_base: HailPath,
    temp_dir: HailPath = "/tmp",
) -> None:
    """
    Compute population haplotypes from VCF files with gnomAD frequency annotations.

    Reads VCF files, annotates variants with gnomAD population allele frequencies,
    extracts phased haplotypes per population using two overlapping window strategies,
    and writes the union of both windowed results as a keyed Hail table.

    Args:
        vcfs_path: Path or glob pattern to input VCF files (GRCh38).
        gnomad_va_file: Path to the gnomAD variant annotations Hail table
            (from extract_gnomad_afs).
        gnomad_sa_file: Path to the gnomAD sample metadata Hail table
            (from extract_gnomad_afs).
        window_size: Base window size in bp for grouping variants into haplotypes.
        freq_threshold: Minimum gnomAD population allele frequency to retain a variant.
        output_base: Base output path; writes {output_base}.1.ht, {output_base}.2.ht,
            and the final {output_base}.ht.
        temp_dir: Directory for Hail temporary files.
    """
    hl.init(tmp_dir=temp_dir)

    gnomad_sa = hl.read_table(gnomad_sa_file)
    gnomad_va = hl.read_table(gnomad_va_file)
    # Keep only variants common enough in at least one gnomAD population.
    gnomad_va = gnomad_va.filter(hl.max(gnomad_va.pop_freqs.map(lambda x: x.AF)) >= freq_threshold)

    mt = hl.import_vcf(vcfs_path, reference_genome="GRCh38", min_partitions=64)
    mt = mt.select_rows().select_cols()
    # Attach gnomAD per-population frequencies; drop variants absent from gnomAD.
    mt = mt.annotate_rows(freq=gnomad_va[mt.row_key].pop_freqs)
    mt = mt.filter_rows(hl.is_defined(mt.freq))

    # Map each sample's population code to an integer index; drop samples with
    # no known population.
    pop_legend: list[str] = gnomad_va.globals.pops.collect()[0]
    pop_ints = {pop: i for i, pop in enumerate(pop_legend)}
    mt = mt.annotate_cols(pop_int=hl.literal(pop_ints).get(gnomad_sa[mt.col_key].pop))
    mt = mt.filter_cols(hl.is_defined(mt.pop_int))
    mt = mt.add_row_index().add_col_index()
    # Entry-level filter: keep a genotype only if the variant is common in that
    # sample's own population.
    mt = mt.filter_entries(mt.freq[mt.pop_int].AF >= freq_threshold)

    # For each variant, record which (population, sample) pairs carry a non-ref
    # allele on each phase side, plus in-cohort call stats per population.
    # NOTE(review): GT[0]/GT[1] indexing presumes phased diploid calls —
    # confirm inputs are phased VCFs.
    mt = mt.annotate_rows(
        pops_and_ids_left=hl.agg.filter(
            mt.GT[0] != 0, hl.agg.collect(hl.struct(pop=mt.pop_int, sample=mt.col_idx))
        ),
        pops_and_ids_right=hl.agg.filter(
            mt.GT[1] != 0, hl.agg.collect(hl.struct(pop=mt.pop_int, sample=mt.col_idx))
        ),
        frequencies_by_pop=hl.agg.group_by(mt.pop_int, hl.agg.call_stats(mt.GT, 2)),
    )
    ht = mt.rows().select(
        "freq",
        "pops_and_ids_left",
        "pops_and_ids_right",
        "row_idx",
        "frequencies_by_pop",
    )

    # Two passes with windows offset by half a window, so a haplotype spanning a
    # window boundary in one pass falls inside a window in the other.
    # NOTE(review): near the start of a contig the shifted window start can go
    # below position 1 — confirm Hail locus arithmetic tolerates this.
    window1 = _get_haplotypes(
        ht,
        lambda locus: locus - (locus.position % window_size),
        1,
        output_base,
        pop_ints,
    )
    window2 = _get_haplotypes(
        ht,
        lambda locus: locus - ((locus.position + window_size // 2) % window_size),
        2,
        output_base,
        pop_ints,
    )

    htu = window1.union(window2)
    logger.info("Writing final %s.ht ...", output_base)
    htu.key_by("haplotype").naive_coalesce(64).write(f"{output_base}.ht", overwrite=True)