Skip to content

Commit 65f9b42

Browse files
committed
feat: get basic tests working
1 parent 403ece8 commit 65f9b42

File tree

5 files changed

+84
-30
lines changed

5 files changed

+84
-30
lines changed

divref/divref/haplotype.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def to_hashable_items(d: dict[str, _V]) -> tuple[tuple[str, _V], ...]:
2626
return tuple(sorted(d.items()))
2727

2828

29-
def get_haplo_sequence(context_size: int, variants: Any) -> Any:
29+
def get_haplo_sequence(context_size: int, variants: Any, reference_genome: str = "GRCh38") -> Any:
3030
"""
3131
Construct a haplotype sequence string with flanking genomic context.
3232
@@ -38,6 +38,7 @@ def get_haplo_sequence(context_size: int, variants: Any) -> Any:
3838
context_size: Number of reference bases to include flanking each end.
3939
variants: Hail array expression of variant structs with locus and alleles fields.
4040
Must contain at least one variant.
41+
reference_genome: Name of the reference genome. Defaults to "GRCh38".
4142
4243
Returns:
4344
Hail string expression representing the full haplotype sequence.
@@ -60,7 +61,7 @@ def get_haplo_sequence(context_size: int, variants: Any) -> Any:
6061
min_pos,
6162
before=context_size,
6263
after=(max_pos - min_pos + max_variant_size + context_size - 1),
63-
reference_genome="GRCh38",
64+
reference_genome=reference_genome,
6465
)
6566

6667
# (min_pos - index_translation) equals context_size, mapping locus positions to string indices
@@ -117,7 +118,7 @@ def split_haplotypes(ht: Any, window_size: int) -> Any:
117118
Hail table with haplotypes exploded into sub-haplotypes by window.
118119
"""
119120
breakpoints = hl.range(1, hl.len(ht.variants)).filter(
120-
lambda i: (i == 0) | (variant_distance(ht.variants[i - 1], ht.variants[i]) >= window_size)
121+
lambda i: variant_distance(ht.variants[i - 1], ht.variants[i]) >= window_size
121122
)
122123

123124
def get_range(i: Any) -> Any:

divref/tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Shared pytest fixtures for the divref test suite."""
22

33
from collections.abc import Generator
4+
from pathlib import Path
5+
from typing import Any
46

57
import hail as hl
68
import pytest
@@ -18,3 +20,20 @@ def hail_context() -> Generator[None, None, None]:
1820
hl.init(quiet=True)
1921
yield
2022
hl.stop()
23+
24+
25+
@pytest.fixture
def datadir() -> Path:
    """Return the ``data`` directory that sits next to this conftest module."""
    tests_dir = Path(__file__).parent
    return tests_dir / "data"
28+
29+
30+
@pytest.fixture
def hail_reference_genome(hail_context: None) -> Any:  # noqa: ARG001
    """
    A small custom reference genome for use in testing.

    Requesting ``hail_context`` makes the ordering explicit: Hail is started by the
    shared fixture before the genome is built, regardless of the order in which a
    test lists its fixtures.

    Returns:
        Hail reference genome named "test_chr1" with a single contig, "chr1", of
        length 1000 and no X, Y, or mitochondrial contigs.
    """
    # NOTE(review): the genome name is fixed, so building it a second time within one
    # Hail session may collide with the first registration -- confirm before using
    # this fixture from more than one test.
    contig = "chr1"
    return hl.ReferenceGenome(
        "test_chr1", [contig], {contig: 1000}, x_contigs=[], y_contigs=[], mt_contigs=[]
    )

divref/tests/data/test.fa

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
> chr1
2+
AGCTGGTGTTCGGTTCGGTAACGGAGAATCTGTGGGGCTATGTCACTAATACTTTCGAAACGCCCCGTACCGATGCTGAACAAGTCGATGCAGGCTCCCG
3+
TCTTTGAATAGGGGTAAACATACAAGTCGATAGAAGATGGGTAGGGGCCTCCAATTCATCCAACACTCTACGCCTTCTCCAAGAGCTAGTAGGGCACCCT
4+
GCAGTTGGAAAGGGAACTATTTCGTAGGGCGAGCCCATACCGTCTCTCTTGCGGAAGACTTAACACGATAGGAAGCTGGAATAGTTTCGAACGATGGTTA
5+
TTAATCCTAATAACGGAACGCTGTCTGGAGGATGAGTGTGACGGAGTGTAACTCGATGAGTTACCCGCTAATCGAACTGGGCGAGAGATCCCAGCGCTGA
6+
TGCACTCGATCCCGAGGCCTGACCCGACATATCAGCTCAGACTAGAGCGGGGCTGTTGACGTTTGGGGTTGAAAAAATCTATTGTACCAATCGGCTTCAA
7+
CGTGCTCCACGGCTGGCGCCTGAGGAGGGGCCCACACCGAGGAAGTAGACTGTTGCACGTTGGCGATGGCGGTAGCTAACTAAGTCGCCTGCCACAACAA
8+
CAGTATCAAAGCCGTATAAAGGGAACATCCACACTTTAGTGAATCGAAGCGCGGCATCAGAATTTCCTTTTGGATACCTGATACAAAGCCCATCGTGGTC
9+
CTTAGACTTCGTGCACATACAGCTGCACCGCACGCATGTGGAATTAGAGGCGAAGTACGATTCCTAGACCGACGTACGATACAACTATGTGGATGTGACG
10+
AGCTTCTTTTATATGCTTCGCCCGCCGGACCGGCCTCGCGATGGCGTAGCTGCGCATAAGCAAATGACAATTAACCACTGTGTACTCGTTATAACATCTG
11+
GCAGTTAAAGTCGGGAGAATAGGAGCCGCAATACACAGTTTACCGCATCTAGACCTAACTGAGATACTGCCATAGACGACTAGCCATCCCTCTGGCTCTT

divref/tests/data/test.fa.fai

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
chr1 1000 7 100 101

divref/tests/test_haplotype.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,70 @@
11
"""Tests for shared Hail utilities in haplotype.py."""
22

3+
from pathlib import Path
34
from typing import Any
45

6+
import hail as hl
57
import pytest
68

79
from divref.haplotype import get_haplo_sequence
810
from divref.haplotype import split_haplotypes
911
from divref.haplotype import to_hashable_items
1012
from divref.haplotype import variant_distance
1113

12-
hl = pytest.importorskip("hail")
14+
# ---------------------------------------------------------------------------
15+
# Helper functions
16+
# ---------------------------------------------------------------------------
17+
18+
19+
def _make_variant(position: int, ref: str, alt: str) -> Any:
    """Build a variant struct on contig "chr1" with alleles ``[ref, alt]``."""
    locus = hl.Struct(contig="chr1", position=position)
    alleles = [ref, alt]
    return hl.Struct(locus=locus, alleles=alleles)
21+
22+
23+
def _make_haplotype_table(variant_positions: list[tuple[int, str, str]]) -> Any:
    """
    Build a one-row Hail table describing a single haplotype.

    Args:
        variant_positions: ``(position, ref, alt)`` tuples, one per variant.

    Returns:
        Hail table whose only row has ``variants``, ``haplotype`` and ``gnomad_freqs``
        arrays, each holding one entry per input variant.
    """
    locus_type = hl.tstruct(position=hl.tint32)
    variant_type = hl.tstruct(locus=locus_type, alleles=hl.tarray(hl.tstr))
    schema = hl.tstruct(
        variants=hl.tarray(variant_type),
        haplotype=hl.tarray(hl.tstr),
        gnomad_freqs=hl.tarray(hl.tfloat64),
    )

    variants = [
        {"locus": {"position": position}, "alleles": [ref_allele, alt_allele]}
        for position, ref_allele, alt_allele in variant_positions
    ]
    n_variants = len(variants)

    # The haplotype entries are just the variant indices rendered as strings, and
    # every variant gets the same placeholder frequency.
    row = {
        "variants": variants,
        "haplotype": [str(index) for index in range(n_variants)],
        "gnomad_freqs": [0.1] * n_variants,
    }
    return hl.Table.parallelize([row], schema=schema)
1343

1444

1545
# ---------------------------------------------------------------------------
1646
# get_haplo_sequence
1747
# ---------------------------------------------------------------------------
1848

1949

50+
def test_get_haplo_sequence_single(
    datadir: Path,
    hail_reference_genome: Any,
    hail_context: None,  # noqa: ARG001
) -> None:
    """get_haplo_sequence should place the alt allele between the reference flanks."""
    test_fasta: Path = datadir / "test.fa"
    test_fai: Path = datadir / "test.fa.fai"

    hail_reference_genome.add_sequence(str(test_fasta), str(test_fai))

    # Read the reference bases directly from the FASTA so the expected value is
    # derived from the test data rather than hard-coded.
    reference_bases = "".join(
        line.strip()
        for line in test_fasta.read_text().splitlines()
        if not line.startswith(">")
    )

    context_size = 2
    position = 100  # 1-based locus position
    ref = reference_bases[position - 1]  # keep the variant consistent with the FASTA
    alt = "C"
    assert ref != alt, "test variant must differ from the reference base"

    variant = _make_variant(position=position, ref=ref, alt=alt)
    haplo_sequence = hl.eval(
        get_haplo_sequence(
            context_size=context_size,
            variants=[variant],
            reference_genome=hail_reference_genome.name,
        )
    )

    # Expected: context_size reference bases, the alt allele, then context_size
    # reference bases following the ref allele.
    ref_start = position - 1
    ref_end = ref_start + len(ref)
    expected = (
        reference_bases[ref_start - context_size : ref_start]
        + alt
        + reference_bases[ref_end : ref_end + context_size]
    )
    assert haplo_sequence == expected
66+
67+
2068
def test_get_haplo_sequence_empty_list_raises() -> None:
2169
"""get_haplo_sequence should raise ValueError when given an empty list."""
2270
with pytest.raises(ValueError, match="at least one variant"):
@@ -51,10 +99,6 @@ def test_to_hashable_items_sorted_by_key() -> None:
5199
# ---------------------------------------------------------------------------
52100

53101

54-
def _make_variant(position: int, ref: str, alt: str) -> Any:
55-
return hl.Struct(locus=hl.Struct(position=position), alleles=[ref, alt])
56-
57-
58102
def test_variant_distance_adjacent_snps(hail_context: None) -> None: # noqa: ARG001
59103
# SNP at 100, next SNP at 101: distance = 101 - 100 - len("A") = 0
60104
assert (
@@ -81,28 +125,6 @@ def test_variant_distance_deletion_closes_gap(hail_context: None) -> None: # no
81125
# ---------------------------------------------------------------------------
82126

83127

84-
def _make_haplotype_table(variant_positions: list[tuple[int, str, str]]) -> Any:
85-
variant_type = hl.tstruct(locus=hl.tstruct(position=hl.tint32), alleles=hl.tarray(hl.tstr))
86-
row_type = hl.tstruct(
87-
variants=hl.tarray(variant_type),
88-
haplotype=hl.tarray(hl.tstr),
89-
gnomad_freqs=hl.tarray(hl.tfloat64),
90-
)
91-
variants = [
92-
{"locus": {"position": pos}, "alleles": [ref, alt]} for pos, ref, alt in variant_positions
93-
]
94-
return hl.Table.parallelize(
95-
[
96-
{
97-
"variants": variants,
98-
"haplotype": [str(i) for i in range(len(variants))],
99-
"gnomad_freqs": [0.1] * len(variants),
100-
}
101-
],
102-
schema=row_type,
103-
)
104-
105-
106128
def test_split_haplotypes_no_split_needed(hail_context: None) -> None: # noqa: ARG001
107129
# All variants within window_size=200; haplotype is kept intact as one row
108130
ht = _make_haplotype_table([(100, "A", "T"), (150, "C", "G"), (190, "G", "A")])

0 commit comments

Comments
 (0)