Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions divref/divref/haplotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def to_hashable_items(d: dict[str, _V]) -> tuple[tuple[str, _V], ...]:
return tuple(sorted(d.items()))


def get_haplo_sequence(context_size: int, variants: Any) -> Any:
def get_haplo_sequence(context_size: int, variants: Any, reference_genome: str = "GRCh38") -> Any:
"""
Construct a haplotype sequence string with flanking genomic context.

Expand All @@ -38,6 +38,7 @@ def get_haplo_sequence(context_size: int, variants: Any) -> Any:
context_size: Number of reference bases to include flanking each end.
variants: Hail array expression of variant structs with locus and alleles fields.
Must contain at least one variant.
reference_genome: Name of the reference genome. Defaults to "GRCh38".

Returns:
Hail string expression representing the full haplotype sequence.
Expand All @@ -60,7 +61,7 @@ def get_haplo_sequence(context_size: int, variants: Any) -> Any:
min_pos,
before=context_size,
after=(max_pos - min_pos + max_variant_size + context_size - 1),
reference_genome="GRCh38",
reference_genome=reference_genome,
)

# (min_pos - index_translation) equals context_size, mapping locus positions to string indices
Expand Down Expand Up @@ -100,7 +101,7 @@ def variant_distance(v1: Any, v2: Any) -> Any:
return v2.locus.position - v1.locus.position - hl.len(v1.alleles[0])


def split_haplotypes(ht: Any, window_size: int) -> Any:
def split_haplotypes(ht: hl.Table, window_size: int) -> hl.Table:
"""
Split multi-variant haplotypes at gaps of at least `window_size` bases.

Expand All @@ -117,7 +118,7 @@ def split_haplotypes(ht: Any, window_size: int) -> Any:
Hail table with haplotypes exploded into sub-haplotypes by window.
"""
breakpoints = hl.range(1, hl.len(ht.variants)).filter(
lambda i: (i == 0) | (variant_distance(ht.variants[i - 1], ht.variants[i]) >= window_size)
lambda i: variant_distance(ht.variants[i - 1], ht.variants[i]) >= window_size
)

def get_range(i: Any) -> Any:
Expand Down
19 changes: 19 additions & 0 deletions divref/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shared pytest fixtures for the divref test suite."""

from collections.abc import Generator
from pathlib import Path

import hail as hl
import pytest
Expand All @@ -18,3 +19,21 @@ def hail_context() -> Generator[None, None, None]:
hl.init(quiet=True)
yield
hl.stop()


@pytest.fixture
def datadir() -> Path:
"""Path to tests/data."""
return Path(__file__).parent / "data"


@pytest.fixture
def hail_reference_genome(hail_context: None) -> hl.ReferenceGenome: # noqa: ARG001
"""A small custom reference genome for use in testing."""
contigs: list[str] = ["chr1"]
lengths: dict[str, int] = {"chr1": 1000}

reference_genome = hl.ReferenceGenome(
"test_chr1", contigs, lengths, x_contigs=[], y_contigs=[], mt_contigs=[]
)
return reference_genome
11 changes: 11 additions & 0 deletions divref/tests/data/test.fa
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
>chr1
AGCTGGTGTTCGGTTCGGTAACGGAGAATCTGTGGGGCTATGTCACTAATACTTTCGAAACGCCCCGTACCGATGCTGAACAAGTCGATGCAGGCTCCCG
TCTTTGAATAGGGGTAAACATACAAGTCGATAGAAGATGGGTAGGGGCCTCCAATTCATCCAACACTCTACGCCTTCTCCAAGAGCTAGTAGGGCACCCT
GCAGTTGGAAAGGGAACTATTTCGTAGGGCGAGCCCATACCGTCTCTCTTGCGGAAGACTTAACACGATAGGAAGCTGGAATAGTTTCGAACGATGGTTA
TTAATCCTAATAACGGAACGCTGTCTGGAGGATGAGTGTGACGGAGTGTAACTCGATGAGTTACCCGCTAATCGAACTGGGCGAGAGATCCCAGCGCTGA
TGCACTCGATCCCGAGGCCTGACCCGACATATCAGCTCAGACTAGAGCGGGGCTGTTGACGTTTGGGGTTGAAAAAATCTATTGTACCAATCGGCTTCAA
CGTGCTCCACGGCTGGCGCCTGAGGAGGGGCCCACACCGAGGAAGTAGACTGTTGCACGTTGGCGATGGCGGTAGCTAACTAAGTCGCCTGCCACAACAA
CAGTATCAAAGCCGTATAAAGGGAACATCCACACTTTAGTGAATCGAAGCGCGGCATCAGAATTTCCTTTTGGATACCTGATACAAAGCCCATCGTGGTC
CTTAGACTTCGTGCACATACAGCTGCACCGCACGCATGTGGAATTAGAGGCGAAGTACGATTCCTAGACCGACGTACGATACAACTATGTGGATGTGACG
AGCTTCTTTTATATGCTTCGCCCGCCGGACCGGCCTCGCGATGGCGTAGCTGCGCATAAGCAAATGACAATTAACCACTGTGTACTCGTTATAACATCTG
GCAGTTAAAGTCGGGAGAATAGGAGCCGCAATACACAGTTTACCGCATCTAGACCTAACTGAGATACTGCCATAGACGACTAGCCATCCCTCTGGCTCTT
1 change: 1 addition & 0 deletions divref/tests/data/test.fa.fai
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
chr1 1000 6 100 101
110 changes: 79 additions & 31 deletions divref/tests/test_haplotype.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,83 @@
"""Tests for shared Hail utilities in haplotype.py."""

from typing import Any
from pathlib import Path

import hail as hl
import pytest

from divref.haplotype import get_haplo_sequence
from divref.haplotype import split_haplotypes
from divref.haplotype import to_hashable_items
from divref.haplotype import variant_distance

hl = pytest.importorskip("hail")
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------


def _make_variant(position: int, ref: str, alt: str) -> hl.Struct:
return hl.Struct(locus=hl.Struct(contig="chr1", position=position), alleles=[ref, alt])


def _make_haplotype_table(variant_positions: list[tuple[str, int, str, str]]) -> hl.Table:
variant_type = hl.tstruct(
locus=hl.tstruct(contig=hl.tstr, position=hl.tint32), alleles=hl.tarray(hl.tstr)
)
row_type = hl.tstruct(
variants=hl.tarray(variant_type),
haplotype=hl.tarray(hl.tstr),
gnomad_freqs=hl.tarray(hl.tfloat64),
)
variants = [
{"locus": {"contig": contig, "position": pos}, "alleles": [ref, alt]}
for contig, pos, ref, alt in variant_positions
]
return hl.Table.parallelize(
[
{
"variants": variants,
"haplotype": [str(i) for i in range(len(variants))],
"gnomad_freqs": [0.1] * len(variants),
}
],
schema=row_type,
)


# ---------------------------------------------------------------------------
# get_haplo_sequence
# ---------------------------------------------------------------------------


def test_get_haplo_sequence_single(
datadir: Path,
hail_reference_genome: hl.ReferenceGenome,
hail_context: None, # noqa: ARG001
) -> None:
"""get_haplo_sequence should return the correct haplotype sequence."""
test_fasta: Path = datadir / "test.fa"
test_fai: Path = datadir / "test.fa.fai"

hail_reference_genome.add_sequence(str(test_fasta), str(test_fai))

variant: hl.Struct = _make_variant(position=100, ref="A", alt="C")
haplo_seq = get_haplo_sequence(
context_size=2, variants=[variant], reference_genome=hail_reference_genome.name
)
assert hl.eval(haplo_seq) == "CCCTC"


def test_get_haplo_sequence_invalid_reference_genome_raises(
hail_context: None, # noqa: ARG001
) -> None:
"""get_haplo_sequence should raise KeyError when given an unregistered reference genome."""
variant = _make_variant(position=100, ref="A", alt="C")
with pytest.raises(KeyError, match="nonexistent_genome"):
get_haplo_sequence(
context_size=2, variants=[variant], reference_genome="nonexistent_genome"
)


def test_get_haplo_sequence_empty_list_raises() -> None:
"""get_haplo_sequence should raise ValueError when given an empty list."""
with pytest.raises(ValueError, match="at least one variant"):
Expand Down Expand Up @@ -51,10 +112,6 @@ def test_to_hashable_items_sorted_by_key() -> None:
# ---------------------------------------------------------------------------


def _make_variant(position: int, ref: str, alt: str) -> Any:
return hl.Struct(locus=hl.Struct(position=position), alleles=[ref, alt])


def test_variant_distance_adjacent_snps(hail_context: None) -> None: # noqa: ARG001
# SNP at 100, next SNP at 101: distance = 101 - 100 - len("A") = 0
assert (
Expand All @@ -81,31 +138,13 @@ def test_variant_distance_deletion_closes_gap(hail_context: None) -> None: # no
# ---------------------------------------------------------------------------


def _make_haplotype_table(variant_positions: list[tuple[int, str, str]]) -> Any:
variant_type = hl.tstruct(locus=hl.tstruct(position=hl.tint32), alleles=hl.tarray(hl.tstr))
row_type = hl.tstruct(
variants=hl.tarray(variant_type),
haplotype=hl.tarray(hl.tstr),
gnomad_freqs=hl.tarray(hl.tfloat64),
)
variants = [
{"locus": {"position": pos}, "alleles": [ref, alt]} for pos, ref, alt in variant_positions
]
return hl.Table.parallelize(
[
{
"variants": variants,
"haplotype": [str(i) for i in range(len(variants))],
"gnomad_freqs": [0.1] * len(variants),
}
],
schema=row_type,
)


def test_split_haplotypes_no_split_needed(hail_context: None) -> None: # noqa: ARG001
# All variants within window_size=200; haplotype is kept intact as one row
ht = _make_haplotype_table([(100, "A", "T"), (150, "C", "G"), (190, "G", "A")])
ht = _make_haplotype_table([
("chr1", 100, "A", "T"),
("chr1", 150, "C", "G"),
("chr1", 190, "G", "A"),
])
rows = split_haplotypes(ht, window_size=200).collect()
assert len(rows) == 1
assert len(rows[0].variants) == 3
Expand All @@ -114,7 +153,12 @@ def test_split_haplotypes_no_split_needed(hail_context: None) -> None: # noqa:
def test_split_haplotypes_splits_at_large_gap(hail_context: None) -> None: # noqa: ARG001
# Gap between positions 101 and 500 (398 bases) exceeds window_size=200;
# results in two sub-haplotypes: [v0, v1] and [v2, v3]
ht = _make_haplotype_table([(100, "A", "T"), (101, "C", "G"), (500, "G", "A"), (501, "T", "C")])
ht = _make_haplotype_table([
("chr1", 100, "A", "T"),
("chr1", 101, "C", "G"),
("chr1", 500, "G", "A"),
("chr1", 501, "T", "C"),
])
rows = sorted(
split_haplotypes(ht, window_size=200).collect(),
key=lambda r: r.variants[0].locus.position,
Expand All @@ -127,7 +171,11 @@ def test_split_haplotypes_splits_at_large_gap(hail_context: None) -> None: # no
def test_split_haplotypes_discards_singleton_segment(hail_context: None) -> None: # noqa: ARG001
# Gap after position 100 isolates it as a singleton (discarded);
# only the two-variant segment [500, 501] is kept
ht = _make_haplotype_table([(100, "A", "T"), (500, "C", "G"), (501, "G", "A")])
ht = _make_haplotype_table([
("chr1", 100, "A", "T"),
("chr1", 500, "C", "G"),
("chr1", 501, "G", "A"),
])
rows = split_haplotypes(ht, window_size=200).collect()
assert len(rows) == 1
assert [v.locus.position for v in rows[0].variants] == [500, 501]
Loading