From 8b0493e5346cd6553caa5e09117e7a168b6541a0 Mon Sep 17 00:00:00 2001 From: earx Date: Thu, 19 Jun 2025 09:49:18 +0200 Subject: [PATCH] make variant matrix 5x5 --- bin/stats/pluviometer.py | 6 +++--- bin/stats/site_variant_readers.py | 10 +++++----- bin/stats/test_pluviometer.py | 21 ++++++++++++++++++++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/bin/stats/pluviometer.py b/bin/stats/pluviometer.py index dbb15e9..e481309 100644 --- a/bin/stats/pluviometer.py +++ b/bin/stats/pluviometer.py @@ -102,7 +102,7 @@ class SiteFilter: def __init__(self, cov_threshold: int, edit_threshold: int) -> None: self.cov_threshold: int = cov_threshold self.edit_threshold: int = edit_threshold - self.frequencies: NDArray[np.int32] = np.zeros(4, np.int32) + self.frequencies: NDArray[np.int32] = np.zeros(5, np.int32) def apply(self, variant_data: SiteVariantData) -> None: if variant_data.coverage >= self.cov_threshold: @@ -127,9 +127,9 @@ def __init__(self, site_filter: SiteFilter) -> None: Rows and column indices correspond to bases in alphabetic order (ACGT) Row-columns corresponding to the same base (e.g. (0,0) -> (A,A)) do not represent edits, and should remain 0 """ - self.edit_read_freqs: NDArray[np.int32] = np.zeros((4, 4), dtype=np.int32) + self.edit_read_freqs: NDArray[np.int32] = np.zeros((5, 5), dtype=np.int32) - self.edit_site_freqs: NDArray[np.int32] = np.zeros((4, 4), dtype=np.int32) + self.edit_site_freqs: NDArray[np.int32] = np.zeros((5, 5), dtype=np.int32) self.filter = site_filter diff --git a/bin/stats/site_variant_readers.py b/bin/stats/site_variant_readers.py index 1a4f376..896c2a6 100644 --- a/bin/stats/site_variant_readers.py +++ b/bin/stats/site_variant_readers.py @@ -33,10 +33,10 @@ def __init__(self, strand: int, edit: str) -> None: self.position: int = 0 self.strand:int = strand assert edit in EDIT_TYPES + NONEDIT_TYPES - self.reference: int = NUC_STR_TO_IND[edit[0]] + self.reference: int = NUC_STR_TO_IND.get(edit[0], 4) self.edited: str = edit[1] - self.frequencies: NDArray[np.int32] = np.zeros(4, dtype=np.int32) - self.frequencies[NUC_STR_TO_IND[self.edited]] = 1 + self.frequencies: NDArray[np.int32] = np.zeros(5, dtype=np.int32) + self.frequencies[NUC_STR_TO_IND.get(self.edited, 4)] = 1 return None @@ -124,11 +124,11 @@ def _parse_parts(self) -> SiteVariantData: return SiteVariantData( seqid=self.parts[REDITOOLS_FIELD_INDEX["Seqid"]], position=int(self.parts[REDITOOLS_FIELD_INDEX["Position"]]) - 1, # Convert Reditools 1-based index to Python's 0-based index - reference=NUC_STR_TO_IND[reference_nuc_str], + reference=NUC_STR_TO_IND.get(reference_nuc_str, 4), strand=strand, coverage=int(self.parts[REDITOOLS_FIELD_INDEX["Coverage"]]), mean_quality=float(self.parts[REDITOOLS_FIELD_INDEX["MeanQ"]]), - frequencies=np.int32(self.parts[REDITOOLS_FIELD_INDEX["Frequencies"]][1:-1].split(",")) + frequencies=np.int32(self.parts[REDITOOLS_FIELD_INDEX["Frequencies"]][1:-1].split(",") + [0]) ) def read(self) -> Optional[SiteVariantData]: diff --git a/bin/stats/test_pluviometer.py b/bin/stats/test_pluviometer.py index 104da72..0e052d0 100644 --- a/bin/stats/test_pluviometer.py +++ b/bin/stats/test_pluviometer.py @@ -5,6 +5,8 @@ from Bio.SeqRecord import SeqRecord from site_variant_readers import TestReader from pluviometer import RecordManager +from pluviometer import SiteFilter +from pluviometer import write_output_file_header from contextlib import nullcontext from typing import Generator from utils import SiteVariantData @@ -32,6 +34,14 @@ def parse_cli_input() -> argparse.Namespace: type=str, help="Name of the output file (leave empty to write to stdout)", ) + parser.add_argument( + "--aggregation_mode", + "-a", + type=str, + default="all", + choices=["all", "cds_longest"], + help='Mode for aggregating counts: "all" aggregates features of every transcript; "cds_longest" aggregates features of the longest CDS or non-coding transcript', + ) return parser.parse_args() @@ -41,6 +51,10 @@ def parse_cli_input() -> argparse.Namespace: with (open(args.gff) as gff_handle, open(args.output, "w") if len(args.output) > 0 else nullcontext(sys.stdout) as output_handle): records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) + + global_filter: SiteFilter = SiteFilter(cov_threshold=1, edit_threshold=1) + + write_output_file_header(output_handle) for record in records: sv_reader = TestReader( @@ -49,5 +63,10 @@ def parse_cli_input() -> argparse.Namespace: ) sv_data: SiteVariantData = sv_reader.read() - manager: RecordManager = RecordManager(record, output_handle) + manager: RecordManager = RecordManager( + record, + global_filter, + output_handle, + args.aggregation_mode + ) manager.scan_and_count(sv_reader)