Skip to content

Commit b8eef18

Browse files
ameynert and claude
committed
feat: add compute_haplotypes tool
Port compute_haplotypes.py from human-diversity-reference/scripts as a defopt-compatible toolkit tool. Reads VCF files, annotates variants with gnomAD population allele frequencies, and extracts phased haplotypes per population using two overlapping genomic windows, writing the union as a keyed Hail table. Extracts the inner get_haplotypes function to module-level _get_haplotypes and replaces typer.echo with logging. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c41aafa commit b8eef18

1 file changed

Lines changed: 222 additions & 0 deletions

File tree

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
"""Tool to compute haplotypes from VCF files with gnomAD population frequency annotations."""
2+
3+
import logging
4+
from typing import Any
5+
from typing import Callable
6+
7+
import hail as hl
8+
9+
from divref.haplotype import HailPath
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def _get_haplotypes(
    ht: Any,
    windower_f: Callable[[Any], Any],
    idx: int,
    output_base: HailPath,
    pop_ints: dict[str, int],
) -> Any:
    """
    Group variants into haplotypes within genomic windows and compute empirical frequencies.

    Applies the windowing function to assign variants to windows, aggregates haplotypes per
    population and sample, filters to multi-variant haplotypes, and collapses across samples
    to compute empirical allele counts and frequencies. Writes intermediate results to a
    checkpoint table.

    Args:
        ht: Hail table with per-variant population membership and frequency data.
        windower_f: Function mapping a Hail locus to the window locus key.
        idx: Index of this windowing pass (1 or 2), used in the checkpoint filename.
        output_base: Base path for output; checkpoint written to {output_base}.{idx}.ht.
        pop_ints: Mapping from population code to integer index.

    Returns:
        Hail table of haplotypes with empirical frequency summaries.
    """
    # Assign each variant to its window: new_locus is the window key shared by
    # all variants that may form a haplotype together in this pass.
    new_locus = windower_f(ht.locus)
    ht = ht.annotate(new_locus=new_locus)

    def agg_haplos(arr: Any) -> Any:
        # Aggregate one phased side (left or right) across all variants in a window.
        # Each element of `arr` is a (pop, sample) struct for a carrier of the variant;
        # tag it with the variant's row_idx so haplotypes can be expressed as row-idx lists.
        flat = hl.agg.explode(lambda elt: hl.agg.collect(elt.annotate(row_idx=ht.row_idx)), arr)
        pop_grouped = hl.group_by(lambda x: x.pop, flat)
        return pop_grouped.map_values(
            lambda arr_per_pop: hl.array(
                # Within a population: group carriers by sample, keep only samples
                # carrying >1 variant in the window (a true multi-variant haplotype),
                # sort each sample's variant indices into a canonical haplotype key,
                # then count how many samples share each identical haplotype.
                hl.array(hl.group_by(lambda inner_elt: inner_elt.sample, arr_per_pop))
                .filter(lambda sample_and_records: hl.len(sample_and_records[1]) > 1)
                .map(
                    lambda sample_and_records: hl.sorted(
                        sample_and_records[1].map(lambda e: e.row_idx)
                    )
                )
                .group_by(lambda x: x)
                .map_values(lambda arr: hl.len(arr))
            )
        )

    # One row per window: a row_idx -> variant-payload lookup plus the per-side
    # haplotype counts (phased left and right chromosomes aggregated separately).
    ht_grouped = ht.group_by("new_locus").aggregate(
        row_map=hl.dict(
            hl.agg.collect((
                ht.row_idx,
                ht.row.select("locus", "alleles", "freq", "frequencies_by_pop"),
            ))
        ),
        left_haplos=agg_haplos(ht.pops_and_ids_left),
        right_haplos=agg_haplos(ht.pops_and_ids_right),
    )

    def collapse_haplos_across_samples(pop: Any, arr1: Any, arr2: Any) -> Any:
        # Assumes all AN == 2 * N_samples.
        # Merge the left- and right-side haplotype count dicts for one population.
        flat = hl.array([arr1, arr2]).flatmap(lambda x: x.get(pop))

        def map_haplo_group(t: Any) -> Any:
            # t is (haplotype, [(haplotype, count), ...]) from the group_by below.
            haplotype = t[0]
            n_observed = hl.sum(t[1].map(lambda x: x[1]))
            component_variant_frequencies = haplotype.map(
                lambda x: ht_grouped.row_map[x].frequencies_by_pop[pop]
            )
            # Use the smallest AN among component variants as the denominator.
            min_an = hl.min(component_variant_frequencies.map(lambda x: x.AN))
            return hl.struct(
                haplotype=haplotype,
                pop=pop,
                empirical_AC=n_observed,
                # AF[1] is the alternate-allele frequency from call_stats.
                min_variant_frequency=hl.min(component_variant_frequencies.map(lambda x: x.AF[1])),
                empirical_AF=n_observed / min_an,
            )

        return hl.array(hl.group_by(lambda x: x[0], flat)).map(map_haplo_group)

    # Collect haplotype frequency structs for every population into one array per window.
    ht_grouped = ht_grouped.annotate(
        all_haplos=hl.literal(list(pop_ints.values())).flatmap(
            lambda pop: collapse_haplos_across_samples(
                pop, ht_grouped.left_haplos, ht_grouped.right_haplos
            )
        )
    )

    def get_haplotype_summary(a: Any) -> dict[str, Any]:
        # Summarize one haplotype across populations: the max_* fields come from
        # the population with the highest empirical AF.
        a_sorted = hl.sorted(a, key=lambda x: x.empirical_AF, reverse=True)
        return dict(
            max_pop=a_sorted[0].pop,
            max_empirical_AF=a_sorted[0].empirical_AF,
            max_empirical_AC=a_sorted[0].empirical_AC,
            min_variant_frequency=a_sorted[0].min_variant_frequency,
            all_pop_freqs=a_sorted.map(lambda x: x.drop("haplotype")),
        )

    # Regroup per-population structs by haplotype and attach cross-population summaries.
    ht_grouped = ht_grouped.transmute(
        all_haplos=hl.array(hl.group_by(lambda x: x.haplotype, ht_grouped.all_haplos)).map(
            lambda t: hl.struct(haplotype=t[0], **get_haplotype_summary(t[1]))
        )
    )

    # One row per (window, haplotype); the window key itself is no longer needed.
    hte = ht_grouped.explode("all_haplos")
    hte = hte.key_by().drop("new_locus")

    def get_variant(row_idx: Any) -> Any:
        # Resolve a row index back to its (locus, alleles) via the window's row_map.
        return hte.row_map[row_idx].select("locus", "alleles")

    def get_gnomad_freq(row_idx: Any) -> Any:
        return hte.row_map[row_idx].freq

    hte = hte.select(
        **hte.all_haplos,
        variants=hte.all_haplos.haplotype.map(get_variant),
        gnomad_freqs=hte.all_haplos.haplotype.map(get_gnomad_freq),
    )

    # A haplotype can appear in multiple windows; keep the row with the highest
    # max_empirical_AF for each haplotype.
    hte = hte.group_by("haplotype").aggregate(
        **hl.sorted(
            hl.agg.collect(hte.row.drop("haplotype")),
            key=lambda row: -row.max_empirical_AF,
        )[0]
    )

    logger.info("Writing %s.%s.ht ...", output_base, idx)
    return hte.checkpoint(f"{output_base}.{idx}.ht", overwrite=True)
139+
140+
141+
def compute_haplotypes(
    *,
    vcfs_path: HailPath,
    gnomad_va_file: HailPath,
    gnomad_sa_file: HailPath,
    window_size: int,
    freq_threshold: float,
    output_base: HailPath,
    temp_dir: HailPath = "/tmp",
) -> None:
    """
    Compute population haplotypes from VCF files with gnomAD frequency annotations.

    Imports the VCFs, keeps variants whose gnomAD population allele frequency
    passes the threshold, runs the haplotype extraction twice with offset
    genomic windows, and writes the union of both passes as a keyed Hail table.

    Args:
        vcfs_path: Path or glob pattern to input VCF files (GRCh38).
        gnomad_va_file: Path to the gnomAD variant annotations Hail table
            (from extract_gnomad_afs).
        gnomad_sa_file: Path to the gnomAD sample metadata Hail table
            (from extract_gnomad_afs).
        window_size: Base window size in bp for grouping variants into haplotypes.
        freq_threshold: Minimum gnomAD population allele frequency to retain a variant.
        output_base: Base output path; writes {output_base}.1.ht, {output_base}.2.ht,
            and the final {output_base}.ht.
        temp_dir: Directory for Hail temporary files.
    """
    hl.init(tmp_dir=temp_dir)

    sample_meta = hl.read_table(gnomad_sa_file)
    variant_annos = hl.read_table(gnomad_va_file)
    # Keep only variants common enough in at least one population.
    variant_annos = variant_annos.filter(
        hl.max(variant_annos.pop_freqs.map(lambda x: x.AF)) >= freq_threshold
    )

    calls = hl.import_vcf(vcfs_path, reference_genome="GRCh38", min_partitions=64)
    calls = calls.select_rows().select_cols()
    calls = calls.annotate_rows(freq=variant_annos[calls.row_key].pop_freqs)
    calls = calls.filter_rows(hl.is_defined(calls.freq))

    # Translate population codes into dense integer indices matching the freq array.
    pops: list[str] = variant_annos.globals.pops.collect()[0]
    pop_to_int = {pop: i for i, pop in enumerate(pops)}
    calls = calls.annotate_cols(pop_int=hl.literal(pop_to_int).get(sample_meta[calls.col_key].pop))
    calls = calls.filter_cols(hl.is_defined(calls.pop_int))
    calls = calls.add_row_index().add_col_index()
    calls = calls.filter_entries(calls.freq[calls.pop_int].AF >= freq_threshold)

    # Record, per variant, which samples carry a non-reference allele on each
    # phased side, plus per-population call statistics.
    calls = calls.annotate_rows(
        pops_and_ids_left=hl.agg.filter(
            calls.GT[0] != 0, hl.agg.collect(hl.struct(pop=calls.pop_int, sample=calls.col_idx))
        ),
        pops_and_ids_right=hl.agg.filter(
            calls.GT[1] != 0, hl.agg.collect(hl.struct(pop=calls.pop_int, sample=calls.col_idx))
        ),
        frequencies_by_pop=hl.agg.group_by(calls.pop_int, hl.agg.call_stats(calls.GT, 2)),
    )
    variant_ht = calls.rows().select(
        "freq",
        "pops_and_ids_left",
        "pops_and_ids_right",
        "row_idx",
        "frequencies_by_pop",
    )

    # Two passes with windows offset by half a window, so a haplotype straddling
    # one pass's boundary falls inside a single window of the other pass.
    first_pass = _get_haplotypes(
        variant_ht,
        lambda locus: locus - (locus.position % window_size),
        1,
        output_base,
        pop_to_int,
    )
    second_pass = _get_haplotypes(
        variant_ht,
        lambda locus: locus - ((locus.position + window_size // 2) % window_size),
        2,
        output_base,
        pop_to_int,
    )

    combined = first_pass.union(second_pass)
    logger.info("Writing final %s.ht ...", output_base)
    combined.key_by("haplotype").naive_coalesce(64).write(f"{output_base}.ht", overwrite=True)

0 commit comments

Comments
 (0)