diff --git a/bin/FeatureOutputWriter.py b/bin/FeatureOutputWriter.py deleted file mode 100644 index 75867d1..0000000 --- a/bin/FeatureOutputWriter.py +++ /dev/null @@ -1,193 +0,0 @@ -from MultiCounter import MultiCounter -from Bio.SeqFeature import SeqFeature -from Bio.SeqFeature import ExactPosition -from collections import defaultdict -from typing import TextIO -from utils import BASE_TYPES, MATCH_MISMATCH_TYPES - -FEATURE_OUTPUT_FIELDS = [ - "SeqID", - "Parents", - "FeatureID", - "Type", - "Start", - "End", - "Strand", - "CoveredSites", - f"GenomeBases[{','.join(BASE_TYPES)}]", - f"SiteBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", - f"ReadBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", -] - -FEATURE_METADATA_OUTPUT_FIELDS = [ - "SeqID", - "ParentsIDs", - "FeatureID", - "Type", - "Start", - "End", - "Strand", -] - -FEATURE_DATA_OUTPUT_FIELDS = [ - "CoveredSites", - f"GenomeBases[{','.join(BASE_TYPES)}]", - f"SiteBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", - f"ReadBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", -] - -AGGREGATE_METADATA_OUTPUT_FIELDS = [ - "SeqID", - "ParentsIDs", - "FeatureID", - "ParentType", - "AggregateType", -] - -AGGREGATE_DATA_OUTPUT_FIELDS = [ - "CoveredSites", - f"GenomeBases[{','.join(BASE_TYPES)}]", - f"SiteBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", - f"ReadBasePairs[{','.join(MATCH_MISMATCH_TYPES)}]", -] - -STR_ZERO_BASE_FREQS = ",".join('0' for _ in range(len(BASE_TYPES))) -STR_ZERO_EDIT_FREQS = ",".join('0' for _ in range(len(MATCH_MISMATCH_TYPES))) - - -def make_parent_path(parent_list: list[str]) -> str: - """ - Create a path string from an ordered list of parent IDs. - The separator is a comma, chosen because it is one of the few invalid characters in tag=value entries of the attributes field in the GFF3 format. 
- - Consult the GFF3 specification for details: https://github.com/The-Sequence-Ontology/Specifications/blob/master/gff3.md - """ - return ','.join(parent_list) - - -class RainFileWriter: - def __init__( - self, handle: TextIO, metadata_fields: list[str], data_fields: list[str] - ): - self.handle = handle - self.metadata_fields: list[str] = metadata_fields - self.n_metadata: int = len(self.metadata_fields) - self.data_fields: list[str] = data_fields - self.n_data: int = len(self.data_fields) - - return None - - def write_header(self) -> int: - b: int = self.handle.write("\t".join(self.metadata_fields)) - b += self.handle.write("\t") - b += self.handle.write("\t".join(self.data_fields)) - b += self.handle.write("\n") - - return b - - def write_comment(self, comment: str) -> int: - b = self.handle.write("# ") - b += self.handle.write(comment) - b += self.handle.write("\n") - - return b - - def write_metadata(self, *metadata_values) -> int: - b: int = 0 - for val in metadata_values: - b += self.handle.write(val) - b += self.handle.write("\t") - - return b - - def write_data(self, *data_values) -> int: - b: int = 0 - for val in data_values[:-1]: - b += self.handle.write(val) - b += self.handle.write("\t") - - b += self.handle.write(data_values[-1]) - b += self.handle.write("\n") - - return b - - -class FeatureFileWriter(RainFileWriter): - def __init__(self, handle: TextIO): - super().__init__( - handle, FEATURE_METADATA_OUTPUT_FIELDS, FEATURE_DATA_OUTPUT_FIELDS - ) - - return None - - def write_metadata(self, record_id: str, feature: SeqFeature) -> int: - return super().write_metadata( - record_id, - make_parent_path(feature.parent_list), - feature.id, - feature.type, - str(feature.location.parts[0].start + ExactPosition(1)), - str(feature.location.parts[-1].end), - str(feature.location.strand), - ) - - def write_row_with_data( - self, record_id: str, feature: SeqFeature, counter: MultiCounter - ) -> int: - return self.write_metadata(record_id, feature) + 
self.write_data( - str(counter.genome_base_freqs.sum()), - ",".join(map(str, counter.genome_base_freqs.flat)), - ",".join(map(str, counter.edit_site_freqs.flat)), - ",".join(map(str, counter.edit_read_freqs.flat)), - ) - - def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: - return self.write_metadata(record_id, feature) + self.write_data( - '0', STR_ZERO_BASE_FREQS, STR_ZERO_EDIT_FREQS, STR_ZERO_EDIT_FREQS - ) - -class AggregateFileWriter(RainFileWriter): - def __init__(self, handle: TextIO): - super().__init__( - handle, AGGREGATE_METADATA_OUTPUT_FIELDS, AGGREGATE_DATA_OUTPUT_FIELDS - ) - - return None - - def write_metadata(self, seq_id: str, feature: SeqFeature, aggregate_type: str) -> int: - return super().write_metadata(seq_id, make_parent_path(feature.parent_list), feature.id, feature.type, aggregate_type) - - def write_rows_with_feature_and_data(self, record_id: str, feature: SeqFeature, counter_dict: defaultdict[str,MultiCounter]) -> int: - b: int = 0 - - for aggregate_type, aggregate_counter in counter_dict.items(): - b += self.write_metadata(record_id, feature, aggregate_type) - b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), - ",".join(map(str, aggregate_counter.genome_base_freqs.flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs.flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs.flat)), - ) - - return b - - def write_rows_with_data( - self, - record_id: str, - parent_list: list[str], - feature_id: str, - feature_type: str, - counter_dict: defaultdict[str,MultiCounter] - ) -> int: - b: int = 0 - - for aggregate_type, aggregate_counter in counter_dict.items(): - b += super().write_metadata(record_id, make_parent_path(parent_list), feature_id, feature_type, aggregate_type) - b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), - ",".join(map(str, aggregate_counter.genome_base_freqs.flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs.flat)), - 
",".join(map(str, aggregate_counter.edit_read_freqs.flat)), - ) - - return b diff --git a/bin/RainFileWriters.py b/bin/RainFileWriters.py new file mode 100644 index 0000000..152cfe0 --- /dev/null +++ b/bin/RainFileWriters.py @@ -0,0 +1,284 @@ +from utils import BASE_TYPES, MATCH_MISMATCH_TYPES +from Bio.SeqFeature import ExactPosition +from MultiCounter import MultiCounter +from Bio.SeqFeature import SeqFeature +from collections import defaultdict +from typing import TextIO + +FEATURE_OUTPUT_FIELDS = [ + "SeqID", + "Parents", + "FeatureID", + "Type", + "Start", + "End", + "Strand", + "CoveredSites", + "GenomeBases", + "SiteBasePairings", + "ReadBasePairings" +] + +FEATURE_METADATA_OUTPUT_FIELDS = [ + "SeqID", + "ParentsIDs", + "FeatureID", + "Type", + "Start", + "End", + "Strand", +] + +FEATURE_DATA_OUTPUT_FIELDS = [ + "CoveredSites", + "GenomeBases", + "SiteBasePairings", + "ReadBasePairings" +] + +AGGREGATE_METADATA_OUTPUT_FIELDS = [ + "SeqID", + "ParentsIDs", + "FeatureID", + "ParentType", + "AggregateType", + "AggregationMode", +] + +AGGREGATE_DATA_OUTPUT_FIELDS = [ + "CoveredSites", + "GenomeBases", + "SiteBasePairings", + "ReadBasePairings", +] + +STR_ZERO_BASE_FREQS = ",".join("0" for _ in range(len(BASE_TYPES))) +STR_ZERO_EDIT_FREQS = ",".join("0" for _ in range(len(MATCH_MISMATCH_TYPES))) + + +def make_parent_path(parent_list: list[str]) -> str: + """ + Create a path string from an ordered list of parent IDs. + The separator is a comma, chosen because it is one of the few invalid characters in tag=value entries of the attributes field in the GFF3 format. 
+ + Consult the GFF3 specification for details: https://github.com/The-Sequence-Ontology/Specifications/blob/master/gff3.md + """ + return ",".join(parent_list) + + +class RainFileWriter: + def __init__(self, handle: TextIO, metadata_fields: list[str], data_fields: list[str]): + self.handle = handle + self.metadata_fields: list[str] = metadata_fields + self.n_metadata: int = len(self.metadata_fields) + self.data_fields: list[str] = data_fields + self.n_data: int = len(self.data_fields) + + return None + + def write_header(self) -> int: + b: int = self.handle.write("\t".join(self.metadata_fields)) + b += self.handle.write("\t") + b += self.handle.write("\t".join(self.data_fields)) + b += self.handle.write("\n") + + return b + + def write_comment(self, comment: str) -> int: + b = self.handle.write("# ") + b += self.handle.write(comment) + b += self.handle.write("\n") + + return b + + def write_metadata(self, *metadata_values) -> int: + b: int = 0 + for val in metadata_values: + b += self.handle.write(val) + b += self.handle.write("\t") + + return b + + def write_data(self, *data_values) -> int: + b: int = 0 + for val in data_values[:-1]: + b += self.handle.write(val) + b += self.handle.write("\t") + + b += self.handle.write(data_values[-1]) + b += self.handle.write("\n") + + return b + + def write_counter_data(self, counter: MultiCounter) -> int: + b: int = self.handle.write(str(counter.genome_base_freqs.sum())) + b += self.handle.write('\t') + b += self.handle.write(",".join(map(str, counter.genome_base_freqs[0:4].flat))) + b += self.handle.write('\t') + b += self.handle.write(",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat))) + b += self.handle.write('\t') + b += self.handle.write(",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat))) + b += self.handle.write('\n') + + return b + + + +class FeatureFileWriter(RainFileWriter): + def __init__(self, handle: TextIO): + super().__init__(handle, FEATURE_METADATA_OUTPUT_FIELDS, FEATURE_DATA_OUTPUT_FIELDS) + + 
return None + + def write_metadata(self, record_id: str, feature: SeqFeature) -> int: + return super().write_metadata( + record_id, + make_parent_path(feature.parent_list), + feature.id, + feature.type, + str(feature.location.parts[0].start + ExactPosition(1)), + str(feature.location.parts[-1].end), + str(feature.location.strand), + ) + + def write_row_with_data( + self, record_id: str, feature: SeqFeature, counter: MultiCounter + ) -> int: + return self.write_metadata(record_id, feature) + self.write_data( + str(counter.genome_base_freqs.sum()), + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + ) + + def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: + return self.write_metadata(record_id, feature) + self.write_data( + "0", STR_ZERO_BASE_FREQS, STR_ZERO_EDIT_FREQS, STR_ZERO_EDIT_FREQS + ) + + +class AggregateFileWriter(RainFileWriter): + def __init__(self, handle: TextIO): + super().__init__(handle, AGGREGATE_METADATA_OUTPUT_FIELDS, AGGREGATE_DATA_OUTPUT_FIELDS) + + return None + + def write_metadata_direct( + self, + seq_id: str, + parent_ids: str, + aggregate_id: str, + parent_type: str, + aggregate_type: str, + aggregation_mode: str + ) -> int: + b: int = self.handle.write(seq_id) + b += self.handle.write('\t') + b += self.handle.write(parent_ids) + b += self.handle.write('\t') + b += self.handle.write(aggregate_id) + b += self.handle.write('\t') + b += self.handle.write(parent_type) + b += self.handle.write('\t') + b += self.handle.write(aggregate_type) + b += self.handle.write('\t') + b += self.handle.write(aggregation_mode) + b += self.handle.write('\t') + + return b + + def write_metadata( + self, seq_id: str, feature: SeqFeature, aggregate_type: str, aggregation_mode: str + ) -> int: + return super().write_metadata( + seq_id, + make_parent_path(feature.parent_list), + feature.id, + feature.type, + 
aggregate_type, + aggregation_mode, + ) + + def write_rows_with_feature_and_data( + self, + record_id: str, + feature: SeqFeature, + aggregation_mode: str, + counter_dict: defaultdict[str, MultiCounter], + ) -> int: + b: int = 0 + + for aggregate_type, aggregate_counter in counter_dict.items(): + b += self.write_metadata(record_id, feature, aggregate_type, aggregation_mode) + b += self.write_data( + str(aggregate_counter.genome_base_freqs.sum()), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ) + + return b + + def write_rows_with_data( + self, + record_id: str, + parent_list: list[str], + feature_id: str, + feature_type: str, + aggregation_mode: str, + counter_dict: defaultdict[str, MultiCounter], + ) -> int: + b: int = 0 + + for aggregate_type, aggregate_counter in counter_dict.items(): + b += super().write_metadata( + record_id, + make_parent_path(parent_list), + feature_id, + feature_type, + aggregate_type, + aggregation_mode, + ) + b += self.write_data( + str(aggregate_counter.genome_base_freqs.sum()), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ) + + return b + + def write_row_chimaera_with_data( + self, record_id: str, feature: SeqFeature, parent_feature: SeqFeature, counter: MultiCounter + ) -> int: + b: int = super().write_metadata( + record_id, + make_parent_path(feature.parent_list), + feature.id, + parent_feature.type, + feature.type, + "chimaera", + ) + b += self.write_data( + str(counter.genome_base_freqs.sum()), + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + ) + + return b + + def 
write_row_chimaera_without_data( + self, record_id: str, feature: SeqFeature, parent_feature: SeqFeature + ) -> int: + b: int = super().write_metadata( + record_id, + make_parent_path(feature.parent_list), + feature.id, + parent_feature.type, + feature.type, + "chimaera", + ) + b += self.write_data("0", STR_ZERO_BASE_FREQS, STR_ZERO_EDIT_FREQS, STR_ZERO_EDIT_FREQS) + + return b diff --git a/bin/SeqFeature_extensions.py b/bin/SeqFeature_extensions.py index 24e80ca..f171021 100644 --- a/bin/SeqFeature_extensions.py +++ b/bin/SeqFeature_extensions.py @@ -6,8 +6,10 @@ logger = logging.getLogger(__name__) -setattr(SeqFeature, "level", 0) +setattr(SeqFeature, "level", 0) +setattr(SeqFeature, "is_chimaera", False) +setattr(SeqFeature, "longest_isoform", None) def get_transcript_like(self: SeqFeature) -> list[tuple[str, str, int]]: """ @@ -26,13 +28,9 @@ def get_transcript_like(self: SeqFeature) -> list[tuple[str, str, int]]: total_cds_length += len(child) if total_cds_length > 0: - transcript_like_list.append( - (transcript_candidate.id, "CDS", total_cds_length) - ) + transcript_like_list.append((transcript_candidate.id, "CDS", total_cds_length)) elif total_exon_length > 0: - transcript_like_list.append( - (transcript_candidate.id, "exon", total_exon_length) - ) + transcript_like_list.append((transcript_candidate.id, "exon", total_exon_length)) return transcript_like_list @@ -41,15 +39,57 @@ def get_transcript_like(self: SeqFeature) -> list[tuple[str, str, int]]: setattr(SeqFeature, "parent_list", [""]) -def make_chimaera(self: SeqFeature) -> None: + +def make_chimaeras2(self: SeqFeature, record_id: str) -> None: + target_type_locations: dict[str, list[SimpleLocation | CompoundLocation]] = {} + + for transcript in self.sub_features: + for child in transcript.sub_features: + type_locations: Optional[list[SimpleLocation | CompoundLocation]] = ( + target_type_locations.get(child.type, None) + ) + if type_locations: + type_locations.extend(child.location.parts) + else: + 
target_type_locations[child.type] = child.location.parts + + chimaeric_type_locations: dict[str, SimpleLocation | CompoundLocation] = { + key: location_union(location_parts) for key, location_parts in target_type_locations.items() + } + + for key, location in chimaeric_type_locations.items(): + chimaera: SeqFeature = SeqFeature( + location=location, + id=f"{self.id}-{key}-chimaera", + type=key+"-chimaera", + qualifiers={"Parent": self.id} + ) + + # if key == "exon" or key == "CDS": + # logging.info( + # f"Record {record_id} · Created {key} chimaera of feature {self.id}: {len(transcript_like_list)} transcripts were merged into one transcript of {len(chimaeric_location_cds_or_exon.parts)} elements" + # ) + + chimaera.sub_features = [] + chimaera.is_chimaera = True + self.sub_features.append(chimaera) + + return None + +setattr(SeqFeature, "make_chimaeras2", make_chimaeras2) + + +def make_chimaeras(self: SeqFeature, record_id: str) -> list[SeqFeature]: """ If the feature contains """ if hasattr(self, "sub_features"): if len(self.sub_features) == 0: - return None + return [] else: - return None + return [] + + new_chimaeras: list[SeqFeature] = [] transcript_like_list: list[SeqFeature] = list( filter( @@ -59,46 +99,102 @@ def make_chimaera(self: SeqFeature) -> None: ) if len(transcript_like_list) == 0: - chimaeric_type: str = "exon" + chimaeric_type_cds_or_exon: str = "exon" transcript_like_list: list[SeqFeature] = list( filter( - lambda transcript: any(map(lambda part: part.type == "exon", transcript.sub_features)), + lambda transcript: any( + map(lambda part: part.type == "exon", transcript.sub_features) + ), self.sub_features, ) ) else: - chimaeric_type: str = "CDS" + chimaeric_type_cds_or_exon: str = "CDS" if len(transcript_like_list) == 0: return None - - target_locations: list[SimpleLocation | CompoundLocation] = [] + target_locations_cds_or_exon: list[SimpleLocation | CompoundLocation] = [] + target_locations_five_prime_utr: list[SimpleLocation | 
CompoundLocation] = [] + target_locations_three_prime_utr: list[SimpleLocation | CompoundLocation] = [] for transcript in transcript_like_list: - target_locations.extend( - list(map( - lambda part: part.location, - filter(lambda part: part.type == chimaeric_type, transcript.sub_features), - )) + target_locations_cds_or_exon.extend( + list( + map( + lambda part: part.location, + filter( + lambda part: part.type == chimaeric_type_cds_or_exon, + transcript.sub_features, + ), + ) + ) + ) + target_locations_five_prime_utr.extend( + list( + map( + lambda part: part.location, + filter(lambda part: part.type == "five_prime_utr", transcript.sub_features), + ) + ) + ) + target_locations_three_prime_utr.extend( + list( + map( + lambda part: part.location, + filter(lambda part: part.type == "three_prime_utr", transcript.sub_features), + ) + ) ) - chimaeric_location: SimpleLocation | CompoundLocation = location_union( - target_locations + chimaeric_location_cds_or_exon: SimpleLocation | CompoundLocation = location_union( + target_locations_cds_or_exon + ) + logging.info( + f"Record {record_id} · Created {chimaeric_type_cds_or_exon} chimaera of feature {self.id}: {len(transcript_like_list)} transcripts were merged into one transcript of {len(chimaeric_location_cds_or_exon.parts)} elements" ) - logging.info(f"Created {chimaeric_type} chimaera of feature {self.id}: {len(transcript_like_list)} transcripts were merged into one transcript of {len(chimaeric_location.parts)} elements") - chimaeric_feature: SeqFeature = SeqFeature( - location=chimaeric_location, - type=chimaeric_type + "-chimaera", + chimaeric_feature_cds_or_exon: SeqFeature = SeqFeature( + location=chimaeric_location_cds_or_exon, + type=chimaeric_type_cds_or_exon + "-chimaera", id=self.id + "-chimaera", qualifiers={"Parent": self.id}, ) + chimaeric_feature_cds_or_exon.is_chimaera = True + chimaeric_feature_cds_or_exon.sub_features = [] + self.sub_features.append(chimaeric_feature_cds_or_exon) + 
new_chimaeras.append(chimaeric_feature_cds_or_exon) + + if len(target_locations_five_prime_utr) > 0: + chimaeric_location_five_prime_utr: SimpleLocation | CompoundLocation = location_union( + target_locations_five_prime_utr + ).parts[0] # Pick only the first element so that there is only one 5'-UTR + chimaeric_feature_five_prime_utr: SeqFeature = SeqFeature( + location=chimaeric_location_five_prime_utr, + type="five_prime_utr-chimaera", + id=self.id + "-chimaera", + qualifiers={"Parent": self.id}, + ) + chimaeric_feature_five_prime_utr.is_chimaera = True + chimaeric_feature_five_prime_utr.sub_features = [] + self.sub_features.append(chimaeric_feature_five_prime_utr) + new_chimaeras.append(chimaeric_feature_five_prime_utr) + + if len(target_locations_three_prime_utr) > 0: + chimaeric_location_three_prime_utr: SimpleLocation | CompoundLocation = location_union( + target_locations_three_prime_utr + ).parts[-1] # Pick only the last element so that there is only one 3'-UTR + chimaeric_feature_three_prime_utr: SeqFeature = SeqFeature( + location=chimaeric_location_three_prime_utr, + type="three_prime_utr-chimaera", + id=self.id + "-chimaera", + qualifiers={"Parent": self.id}, + ) + chimaeric_feature_three_prime_utr.is_chimaera = True + chimaeric_feature_three_prime_utr.sub_features = [] + self.sub_features.append(chimaeric_feature_three_prime_utr) + new_chimaeras.append(chimaeric_feature_three_prime_utr) - chimaeric_feature.sub_features = [] - - self.sub_features.append(chimaeric_feature) - - return None + return new_chimaeras -setattr(SeqFeature, "make_chimaera", make_chimaera) +setattr(SeqFeature, "make_chimaeras", make_chimaeras) diff --git a/bin/new-pluviometer.py b/bin/new-pluviometer.py deleted file mode 100644 index b38bc83..0000000 --- a/bin/new-pluviometer.py +++ /dev/null @@ -1,317 +0,0 @@ -from BCBio import GFF -from Bio.SeqRecord import SeqRecord -from Bio.SeqFeature import SeqFeature -from typing import TextIO, Optional, Generator -import numpy as np -from 
numpy.typing import NDArray -import argparse -from MultiCounter import MultiCounter -from SiteFilter import SiteFilter -from utils import SiteVariantData, condense -from contextlib import nullcontext -import sys -from site_variant_readers import ( - RNAVariantReader, - Reditools2Reader, - Reditools3Reader, - Jacusa2Reader, -) -from FeatureOutputWriter import FeatureFileWriter -from collections import deque, defaultdict -import progressbar -import math - -def parse_cli_input() -> argparse.Namespace: - """Parse command line input""" - - parser = argparse.ArgumentParser(description="Rain counter") - parser.add_argument( - "--sites", - "-s", - type=str, - required=True, - help="File containing per-site base alteration data", - ) - parser.add_argument( - "--gff", - "-g", - type=str, - required=True, - help="Reference genome annotations (GFF3 file)", - ) - parser.add_argument( - "--output", - "-o", - default="", - type=str, - help="Name of the output file (leave empty to write to stdout)", - ) - parser.add_argument( - "--format", - "-f", - type=str, - choices=["reditools2", "reditools3", "jacusa2", "sapin"], - default="reditools3", - help="Sites file format", - ) - parser.add_argument( - "--cov", - "-c", - type=int, - default=0, - help="Site coverage threshold for counting editions", - ) - parser.add_argument( - "--edit_threshold", - "-t", - type=int, - default=1, - help="Minimum number of edited reads for counting a site as edited", - ) - parser.add_argument( - "--aggregation_mode", - "-a", - type=str, - default="all", - choices=["all", "cds_longest"], - help='Mode for aggregating counts: "all" aggregates features of every transcript; "cds_longest" aggregates features of the longest CDS or non-coding transcript', - ) - parser.add_argument( - "--progress", action="store_true", default="false", help="Display progress bar" - ) - - return parser.parse_args() - - - - -class RecordManager: - def __init__(self, record: SeqRecord, writer: FeatureFileWriter, filter: SiteFilter): 
- - self.record: SeqRecord = record - self.writer: FeatureFileWriter = writer - """Dict of multicounters with feature ids as keys""" - self.counters: defaultdict[str, MultiCounter] = defaultdict(self.counter_factory) - """Set of features whose multicounters are currently being updated""" - self.active_features: dict[str,SeqFeature] = dict() - self.filter: SiteFilter = filter - - # Flatten the feature hierarchy - feature_list: list[SeqFeature] = [] - - for feature in record.features: - self._flatten_hierarchy(feature_list, feature, 1) - - self.nb_targets = len(feature_list) - - nb_targets_d_format = math.floor(math.log(self.nb_targets, 10)) - - # Create deques of features sorted by their start position and their end position - # These deques are used for loading and unloading features from the `active_features` set - feature_list.sort(key=lambda feature: feature.location.start) - self.activation_deque: deque[tuple[int, list[SeqFeature]]] = condense(feature_list, "start") - print(f"Features to activate: {len(self.activation_deque)}") - - feature_list.sort(key=lambda feature: feature.location.end) - self.deactivation_deque: deque[tuple[int, list[SeqFeature]]] = condense(feature_list, "end") - print(f"Features to deactivate: {len(self.deactivation_deque)}") - - - # self.features_start_first: deque[SeqFeature] = deque( - # sorted(feature_list, key=lambda feature: feature.location.start), - # ) - - # self.features_end_first: deque[SeqFeature] = deque( - # sorted(feature_list, key=lambda feature: feature.location.end), - # ) - - # self.switchlist: deque[tuple[int,str,str]] = deque() - - - self.use_progress_bar = args.progress - if self.use_progress_bar: - self.progress_bar: progressbar.ProgressBar = progressbar.ProgressBar( - max_value=self.nb_targets, - widgets=[ - f"Record {self.record.id}: Processed target feature ", - progressbar.Counter( - format=f"%(value)0{nb_targets_d_format}d out of %(max_value)d" - ), - " (", - progressbar.Percentage(), - ") ", - 
progressbar.Bar("█", "|", "|"), - " ", - progressbar.Timer(), - " - ", - progressbar.SmoothingETA(), - ], - poll_interval=1, # Updates every 1 second - ) - - def counter_factory(self) -> MultiCounter: - return MultiCounter(self.filter) - - def _flatten_hierarchy(self, feature_list, feature, level) -> None: - """Add elements to a flattened list of features by preorder traversal of a feature hierarchy""" - feature_list.append(feature) - feature.level = level - for child in feature.sub_features: - self._flatten_hierarchy(feature_list, child, level + 1) - - return None - - def update_active_counters(self, site_data: SiteVariantData) -> None: - """ - Update the multicounters matching the ID of features in the `active_features` set. - A new multicounter is created if no matching ID is found. - """ - for feature_key, feature in self.active_features.items(): - counter = self.counters[feature_key] - # if counter: - # print(f"counter found for feature {key}") - if feature.location.strand == site_data.strand: - counter.update(site_data) - - return None - - # def update_queues(self, new_position: int) -> None: - # remove_list: list[str] = [] - - # delete_list: list[str] = [] - # for key, feature in self.active_features.items(): - # if feature.location.end < new_position: - # delete_list.append(key) - - # for key in delete_list: - # del self.active_features[key] - # # self.progress_bar.next() - - # while len(self.features_start_first) > 0 and self.features_start_first[0].location.start <= new_position: - # new_feature: SeqFeature = self.features_start_first.popleft() - - # if new_feature.location.end > new_position: - # self.active_features[new_feature.id] = new_feature - # else: - # # self.progress_bar.next() - # pass - # # print(f"activated feature {new_feature.id}") - - # # while len(self.features_end_first) > 0 and self.features_end_first[0].location.end < new_position: - # # old_feature: SeqFeature = self.features_end_first.popleft() - - # # 
remove_list.append(old_feature.id) - # # # self.active_features.discard(old_feature.id) - # # if old_feature.level == 1: - # # # print(f"deactivated feature {old_feature.id}") - # # self.checkout(old_feature) - - # # for elem in remove_list: - # # self.active_features.discard(elem) - - # return None - - def update_queues(self, new_position: int) -> None: - while len(self.activation_deque) > 0 and self.activation_deque[0][0] <= new_position: - feature_list = self.activation_deque.popleft()[1] - for feature in feature_list: - if feature.location.end < new_position: - if args.progress: - self.progress_bar.next() - else: - self.active_features[feature.id] = feature - - while len(self.deactivation_deque) > 0 and self.deactivation_deque[0][0] < new_position: - feature_list = self.deactivation_deque.popleft()[1] - for feature in feature_list: - self.active_features.pop(feature.id, None) - if feature.level == 1: - self.checkout(feature) - # if args.progress: - # self.progress_bar.next() - - - def checkout(self, feature: SeqFeature) -> None: - # print(f"checking out feature {feature.id}") - if feature.id in self.active_features: - del self.active_features[feature.id] - if args.progress: - self.progress_bar.next() - - counter: Optional[MultiCounter] = self.counters.get(feature.id, MultiCounter(self.filter)) - # print(self.counters.keys()) - if counter: - # print(f"writing data for feature {feature.id}") - self.writer.write_feature_with_data(self.record, feature, counter) - del self.counters[feature.id] - - for child in feature.sub_features: - self.checkout(child) - - return None - - def is_finished(self) -> bool: - result = len(self.active_features) + len(self.activation_deque) + len(self.deactivation_deque) == 0 - # if len(self.activation_deque) == 0: - # print("Activation queue empty") - # if len(self.deactivation_deque) == 0: - # print("Deactivation queue empty") - # if result: - # print("All queues are empty") - - return result - - def launch_counting(self, reader: 
RNAVariantReader) -> None: - svdata: Optional[SiteVariantData] = reader.read() - - while svdata and not self.is_finished(): - # print(self.active_features) - self.update_queues(svdata.position) - self.update_active_counters(svdata) - svdata: SiteVariantData = reader.read() - # print(f"Active counters: {len(self.counters)} - Active features: {len(self.active_features)}") - - - -if __name__ == "__main__": - args = parse_cli_input() - with ( - open(args.gff) as gff_handle, - open(args.sites) as sv_handle, - open(args.output, "w") - if len(args.output) > 0 - else nullcontext(sys.stdout) as output_handle, - ): - match args.format: - case "reditools2": - sv_reader: RNAVariantReader = Reditools2Reader(sv_handle) - case "reditools3": - sv_reader: RNAVariantReader = Reditools3Reader(sv_handle) - case "jacusa2": - sv_reader: RNAVariantReader = Jacusa2Reader(sv_handle) - case _: - raise Exception(f'Unimplemented format "{args.format}"') - - writer: FeatureFileWriter = FeatureFileWriter(output_handle) - writer.write_comment(f"input format: {args.format}") - writer.write_header() - - records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) - - global_filter: SiteFilter = SiteFilter( - cov_threshold=args.cov, edit_threshold=args.edit_threshold - ) - - for record in records: - manager: RecordManager = RecordManager(record, writer, global_filter) - manager.launch_counting(sv_reader) - - print(f"Active: {len(manager.active_features)}") - # if len(manager.active_features) > 0: - # print(*map(lambda x: x.id, manager.active_features)) - print(f"To activate: {len(manager.activation_deque)}") - # if len(manager.activation_deque) > 0: - # print(*map(lambda x: x.id, manager.activation_deque)) - print(f"To deactivate: {len(manager.deactivation_deque)}") - # if len(manager.deactivation_deque) > 0: - # print(*map(lambda x: x.id, manager.deactivation_deque)) diff --git a/bin/pluviometer.py b/bin/pluviometer.py index 0b8adbe..d5c616c 100755 --- a/bin/pluviometer.py +++ 
b/bin/pluviometer.py @@ -1,19 +1,20 @@ #!/usr/bin/env python -from FeatureOutputWriter import FeatureFileWriter, AggregateFileWriter +from RainFileWriters import FeatureFileWriter, AggregateFileWriter from typing import Any, Optional, Generator, Callable, TextIO +from Bio.SeqFeature import SimpleLocation, CompoundLocation from SeqFeature_extensions import SeqFeature from collections import deque, defaultdict from dataclasses import dataclass, field from MultiCounter import MultiCounter from Bio.SeqRecord import SeqRecord from SiteFilter import SiteFilter -from utils import SiteVariantData from site_variant_readers import ( RNAVariantReader, Reditools2Reader, Reditools3Reader, Jacusa2Reader, ) +from utils import SiteVariantData from natsort import natsorted import multiprocessing from BCBio import GFF @@ -52,7 +53,13 @@ class CountingContext: def __init__(self, aggregate_writer: AggregateFileWriter, filter: SiteFilter): self.aggregate_writer: AggregateFileWriter = aggregate_writer self.filter: SiteFilter = filter - self.aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + self.longest_isoform_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + self.chimaera_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + self.all_isoforms_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( DefaultMultiCounterFactory(self.filter) ) self.total_counter: MultiCounter = MultiCounter(self.filter) @@ -61,13 +68,23 @@ def __init__(self, aggregate_writer: AggregateFileWriter, filter: SiteFilter): def update_aggregate_counters(self, new_counters: defaultdict[str, MultiCounter]) -> None: for counter_type, new_counter in new_counters.items(): - target_counter: MultiCounter = self.aggregate_counters[counter_type] + target_counter: MultiCounter = self.longest_isoform_aggregate_counters[counter_type] target_counter.merge(new_counter) return 
None -class RecordCountingContext(CountingContext): +def merge_aggregation_counter_dicts( + dst: defaultdict[str, MultiCounter], src: defaultdict[str, MultiCounter] +) -> None: + for src_type, src_counter in src.items(): + dst_counter: MultiCounter = dst[src_type] + dst_counter.merge(src_counter) + + return None + + +class RecordCountingContext: def __init__( self, feature_writer: FeatureFileWriter, @@ -75,7 +92,20 @@ def __init__( filter: SiteFilter, use_progress_bar: bool, ): - super().__init__(aggregate_writer, filter) + # super().__init__(aggregate_writer, filter) + self.aggregate_writer: AggregateFileWriter = aggregate_writer + self.filter: SiteFilter = filter + self.longest_isoform_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + self.chimaera_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + self.all_isoforms_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + self.total_counter: MultiCounter = MultiCounter(self.filter) + self.active_features: dict[str, SeqFeature] = dict() self.feature_writer: FeatureFileWriter = feature_writer self.counters: defaultdict[str, MultiCounter] = defaultdict( @@ -86,13 +116,30 @@ def __init__( self.svdata: Optional[SiteVariantData] = None self.deactivation_list: list[SeqFeature] = [] - if self.use_progress_bar: - self.progbar_increment = self._active_progbar_increment - else: - self.progbar_increment = self._inactive_progbar_increment + self.progbar_increment: Callable = ( + self._active_progbar_increment + if self.use_progress_bar + else self._inactive_progbar_increment + ) return None - + + def update_aggregate_counters( + self, aggregate_counter_tag: str, new_counters: defaultdict[str, MultiCounter] + ) -> None: + """ + Update a dict of aggregate counters by merging them with matching items in another dict of counters. 
+ The tag refers to the different kinds of aggregate counters: "all_isoforms", "chimaera", and "longest_isoform" + """ + aggregate_counters: defaultdict[str, MultiCounter] = getattr( + self, aggregate_counter_tag + "_aggregate_counters" + ) + for counter_type, new_counter in new_counters.items(): + target_counter: MultiCounter = aggregate_counters[counter_type] + target_counter.merge(new_counter) + + return None + def set_record(self, record: SeqRecord) -> None: logging.info(f"Switching to record {record.id}") if len(self.active_features) != 0: @@ -167,17 +214,16 @@ def load_action_queue( """ # Iterate over the `parts` of a location for compatibility with `SimpleLocation` and `CompoundLocation` - # assert root_feature.location assert root_feature.location - # print(root_feature.location.parts) feature_strand: Optional[int] = root_feature.location.parts[0].strand root_feature.level = level + if not hasattr(root_feature, "is_chimaera"): + root_feature.is_chimaera = False root_feature.parent_list = parent_list - # if "chimaera" not in root_feature.type: if level == 1: - root_feature.make_chimaera() + root_feature.make_chimaeras2(self.record.id) for part in root_feature.location.parts: if feature_strand != part.strand: raise Exception( f"feature {root_feature.id} contains parts on different strands ({feature_strand} and {part.strand}). I cannot work with this!" ) + old_part: Optional[SimpleLocation | CompoundLocation] = None + + for part in root_feature.location.parts: + if old_part: + if part.start in old_part or old_part.start in part: + raise Exception(f"feature {root_feature.id} has a compound location containing overlapping parts. 
There must be no overlapping.") + old_part = part actions: QueueActionList = location_actions[int(part.start)] actions.activate.append(root_feature) - actions: QueueActionList = location_actions[int(part.end)] + actions = location_actions[int(part.end)] actions.deactivate.append(root_feature) # Visit children @@ -209,7 +262,7 @@ def state_update_cycle(self, new_position: int) -> None: while ( len(self.action_queue) > 0 and self.action_queue[0][0] < new_position - ): # Use < instead of <= because of Python's right-exclusive indfgexing _, actions = self.action_queue.popleft() visited_positions += 1 @@ -218,7 +271,7 @@ for feature in actions.deactivate: if feature.level == 1: - self.checkout(feature) + self.checkout(feature, None) self.active_features.pop(feature.id, None) @@ -238,7 +291,7 @@ def flush_queues(self) -> None: for feature in actions.deactivate: if feature.level == 1: - self.checkout(feature) + self.checkout(feature, None) self.active_features.pop(feature.id, None) logging.info( @@ -247,32 +300,97 @@ return None - def checkout(self, feature: SeqFeature) -> defaultdict[str, MultiCounter]: + def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> None: self.active_features.pop(feature.id, None) # Counter for the feature itself - counter: Optional[MultiCounter] = self.counters.get(feature.id, None) + feature_counter: Optional[MultiCounter] = self.counters.get(feature.id, None) + # aggregate_counter: Optional[MultiCounter] = self.counters.get(feature.id, None) - if counter: - self.feature_writer.write_row_with_data(self.record.id, feature, counter) + assert self.record.id + + if feature_counter: + if feature.is_chimaera: + assert parent_feature # A chimaera must always have a parent feature (a gene) + self.aggregate_writer.write_row_chimaera_with_data( + self.record.id, feature, 
parent_feature, feature_counter + ) + self.chimaera_aggregate_counters[feature.type].merge(feature_counter) + else: + self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) del self.counters[feature.id] else: - self.feature_writer.write_row_without_data(self.record.id, feature) + if feature.is_chimaera: + assert parent_feature + self.aggregate_writer.write_row_chimaera_without_data( + self.record.id, feature, parent_feature + ) + else: + self.feature_writer.write_row_without_data(self.record.id, feature) + + all_isoforms_aggregation_counters: Optional[defaultdict[str, MultiCounter]] = None + + assert self.record.id # Placate Pylance # Aggregation counters from the feature's sub-features if feature.level == 1: - aggregation_counters: dict[str, MultiCounter] = self.aggregate_level1(feature) + ( + level1_longest_isoform_aggregation_counters, + level1_all_isoforms_aggregation_counters, + ) = self.aggregate_level1(feature) + merge_aggregation_counter_dicts( + self.all_isoforms_aggregate_counters, level1_all_isoforms_aggregation_counters + ) + + for aggregate_type, aggregate_counter in level1_longest_isoform_aggregation_counters.items(): + self.aggregate_writer.write_metadata_direct( + seq_id=self.record.id, + parent_ids=".," + feature.id, + aggregate_id=feature.longest_isoform + "-longest_isoform", + parent_type=feature.type, + aggregate_type=aggregate_type, + aggregation_mode="longest_isoform" + ) + self.aggregate_writer.write_counter_data(aggregate_counter) + + for aggregate_type, aggregate_counter in level1_all_isoforms_aggregation_counters.items(): + self.aggregate_writer.write_metadata_direct( + seq_id=self.record.id, + parent_ids=".," + feature.id, + aggregate_id=feature.id + "-all_isoforms", + parent_type=feature.type, + aggregate_type=aggregate_type, + aggregation_mode="all_isoforms" + ) + self.aggregate_writer.write_counter_data(aggregate_counter) else: - aggregation_counters: dict[str, MultiCounter] = self.aggregate_children(feature) 
+ feature_aggregation_counters = self.aggregate_children(feature) + for aggregate_type, aggregate_counter in feature_aggregation_counters.items(): + self.aggregate_writer.write_metadata_direct( + seq_id=self.record.id, + parent_ids=','.join(feature.parent_list), + aggregate_id=feature.id, + parent_type=parent_feature.type, + aggregate_type=aggregate_type, + aggregation_mode="feature" + ) + self.aggregate_writer.write_counter_data(aggregate_counter) # Recursively check-out children for child in feature.sub_features: - self.checkout(child) + self.checkout(child, feature) - return aggregation_counters + return None - def aggregate_level1(self, feature: SeqFeature) -> dict[str, MultiCounter]: - level1_aggregation_counters: defaultdict[str, MultiCounter] = defaultdict(DefaultMultiCounterFactory(self.filter)) + def aggregate_level1( + self, feature: SeqFeature + ) -> tuple[defaultdict[str, MultiCounter], defaultdict[str, MultiCounter]]: + level1_longest_isoform_aggregation_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) + level1_all_isoforms_aggregation_counters: defaultdict[str, MultiCounter] = defaultdict( + DefaultMultiCounterFactory(self.filter) + ) # List of tuples of transcript-like sub-features. In each tuple: # - 0: ID of the sub-feature @@ -282,9 +400,9 @@ def aggregate_level1(self, feature: SeqFeature) -> dict[str, MultiCounter]: feature.get_transcript_like() ) # Custom method added to the class - # Select the transcript-like feature that is representative of this gene. + # Select the transcript-like feature that is representative of this gene: the longest isoform. # If there are CDS sub-features, select the onte with greatest total CDS length. Elsewise, select the sub-feature with the greatest total exon length. 
- representative_feature_id: str = "" + longest_isoform_id: str = "" has_cds: bool = False max_total_length: int = 0 @@ -292,49 +410,61 @@ def aggregate_level1(self, feature: SeqFeature) -> dict[str, MultiCounter]: if child_type == "CDS": if has_cds: if child_length > max_total_length: - representative_feature_id = child_id + longest_isoform_id = child_id max_total_length = child_length else: - representative_feature_id = child_id + longest_isoform_id = child_id max_total_length = child_length has_cds = True elif child_type == "exon": if has_cds: continue elif child_length > max_total_length: - representative_feature_id = child_id + longest_isoform_id = child_id max_total_length = child_length + feature.longest_isoform = longest_isoform_id + logging.info( - f"Record {self.record.id}, gene {feature.id}: Selected the transcript {representative_feature_id} with {'CDS' if has_cds else 'exons'} as the representative feature." + f"Record {self.record.id}, gene {feature.id}: The longest isoform is {longest_isoform_id}." ) - # Perform aggregations, selecting only the "representative feature" + # Perform aggregations for child in feature.sub_features: # Compute aggregates in the child. Recursively aggregates on all its children. 
- aggregation_counters_from_child = self.aggregate_children(child) + aggregation_counters_from_child: defaultdict[str, MultiCounter] = ( + self.aggregate_children(child) + ) - if child.id == representative_feature_id: + merge_aggregation_counter_dicts( + level1_all_isoforms_aggregation_counters, aggregation_counters_from_child + ) + + if child.id == longest_isoform_id: # Merge the aggregates from the child with all the other aggregates under this feature - for ( - child_aggregation_type, - child_aggregation_counter, - ) in aggregation_counters_from_child.items(): - aggregation_counter: MultiCounter = level1_aggregation_counters[ - child_aggregation_type - ] - aggregation_counter.merge(child_aggregation_counter) + merge_aggregation_counter_dicts( + level1_longest_isoform_aggregation_counters, aggregation_counters_from_child + ) - self.aggregate_writer.write_rows_with_feature_and_data( - self.record.id, feature, level1_aggregation_counters - ) + # assert self.record.id + # self.aggregate_writer.write_rows_with_feature_and_data( + # self.record.id, feature, "longest_isoform", level1_longest_isoform_aggregation_counters + # ) + # self.aggregate_writer.write_rows_with_feature_and_data( + # self.record.id, feature, "all_isoforms", level1_all_isoforms_aggregation_counters + # ) # Merge the feature-level aggregation counters into the record-level aggregation counters - self.update_aggregate_counters(level1_aggregation_counters) + self.update_aggregate_counters( + "longest_isoform", level1_longest_isoform_aggregation_counters + ) - return level1_aggregation_counters + return ( + level1_longest_isoform_aggregation_counters, + level1_all_isoforms_aggregation_counters, + ) - def aggregate_children(self, feature: SeqFeature) -> dict[str, MultiCounter]: + def aggregate_children(self, feature: SeqFeature) -> defaultdict[str, MultiCounter]: aggregation_counters: defaultdict[str, MultiCounter] = defaultdict( DefaultMultiCounterFactory(self.filter) ) @@ -357,9 +487,9 @@ def 
aggregate_children(self, feature: SeqFeature) -> dict[str, MultiCounter]: if feature_counter: aggregation_counter.merge(feature_counter) - self.aggregate_writer.write_rows_with_feature_and_data( - self.record.id, feature, aggregation_counters - ) + # self.aggregate_writer.write_rows_with_feature_and_data( + # self.record.id, feature, "feature", aggregation_counters + # ) return aggregation_counters @@ -474,22 +604,24 @@ def parse_cli_input() -> argparse.Namespace: class DefaultMultiCounterFactory: + """Callable class to enable the pickling of a MultiCounter factory for multiprocessing (lambdas cannot be pickled)""" + def __init__(self, filter: SiteFilter): self.filter: SiteFilter = filter return None - + def __call__(self): return MultiCounter(self.filter) -def run_job(record: SeqRecord) -> dict[str,Any]: +def run_job(record: SeqRecord) -> dict[str, Any]: """ A wrapper function for performing counting parallelized by record. The return value is a dict containing all the information needed for integrating the output of all records after the computations are finished. """ assert record.id # Stupid assertion for pylance - logging.info(f"Record {record.id} · Start processing the record") + logging.info(f"Record {record.id} · Record parsed. 
Counting begins.") tmp_feature_output_file: str = tempfile.mkstemp()[1] tmp_aggregate_output_file: str = tempfile.mkstemp()[1] @@ -497,8 +629,8 @@ def run_job(record: SeqRecord) -> dict[str,Any]: with ( open(args.sites) as sv_handle, open(tmp_feature_output_file, "w") as tmp_feature_output_handle, - open(tmp_aggregate_output_file, "w") as tmp_aggregate_output_handle - ): + open(tmp_aggregate_output_file, "w") as tmp_aggregate_output_handle, + ): # Set up output feature_writer: FeatureFileWriter = FeatureFileWriter(tmp_feature_output_handle) aggregate_writer: AggregateFileWriter = AggregateFileWriter(tmp_aggregate_output_handle) @@ -516,21 +648,50 @@ def run_job(record: SeqRecord) -> dict[str,Any]: record_ctx.launch_counting(reader) # Write aggregate counter data of the record - aggregate_writer.write_rows_with_data(record.id, ["."], ".", ".", record_ctx.aggregate_counters) + aggregate_writer.write_rows_with_data( + record.id, + ["."], + ".", + ".", + "longest_isoform", + record_ctx.longest_isoform_aggregate_counters, + ) + aggregate_writer.write_rows_with_data( + record.id, + ["."], + ".", + ".", + "all_isoforms", + record_ctx.all_isoforms_aggregate_counters, + ) + aggregate_writer.write_rows_with_data( + record.id, + ["."], + ".", + ".", + "chimaera", + record_ctx.chimaera_aggregate_counters, + ) # Write the total counter data of the record. 
A dummy dict needs to be created to use the `write_rows_with_data` method - total_counter_dict: defaultdict[str, MultiCounter] = defaultdict(lambda: MultiCounter(genome_filter)) + total_counter_dict: defaultdict[str, MultiCounter] = defaultdict( + lambda: MultiCounter(genome_filter) + ) total_counter_dict["."] = record_ctx.total_counter - aggregate_writer.write_rows_with_data(record.id, ["."], ".", ".", total_counter_dict) + aggregate_writer.write_rows_with_data( + record.id, ["."], ".", ".", "all_sites", total_counter_dict + ) return { "record_id": record.id, "tmp_feature_output_file": tmp_feature_output_file, "tmp_aggregate_output_file": tmp_aggregate_output_file, - "aggregate_counters": record_ctx.aggregate_counters, + "chimaera_aggregate_counters": record_ctx.chimaera_aggregate_counters, + "longest_isoform_aggregate_counters": record_ctx.longest_isoform_aggregate_counters, + "all_isoforms_aggregate_counters": record_ctx.all_isoforms_aggregate_counters, "total_counter": record_ctx.total_counter, } - + if __name__ == "__main__": global args @@ -540,9 +701,9 @@ def run_job(record: SeqRecord) -> dict[str,Any]: log_filename: str = args.output + ".pluviometer.log" if args.output else "pluviometer.log" logging.basicConfig(filename=log_filename, level=logging.INFO, format=LOGGING_FORMAT) logging.info(f"Pluviometer started. 
Log file: {log_filename}") - feature_output_filename: str = args.output + ".features.tsv" if args.output else "features.tsv" + feature_output_filename: str = args.output + "features.tsv" if args.output else "features.tsv" aggregate_output_filename: str = ( - args.output + ".aggregates.tsv" if args.output else "aggregates.tsv" + args.output + "aggregates.tsv" if args.output else "aggregates.tsv" ) global reader_factory @@ -562,63 +723,21 @@ def run_job(record: SeqRecord) -> dict[str,Any]: case _: raise Exception(f'Unimplemented format "{args.format}"') - # feature_writer: FeatureFileWriter = FeatureFileWriter(feature_output_handle) - # feature_writer.write_header() - - # aggregate_writer = AggregateFileWriter(aggregate_output_handle) - # aggregate_writer.write_header() - logging.info("Parsing GFF3 file...") records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) - logging.info("GFF3 parsing completed.") - - # global_filter: SiteFilter = SiteFilter( - # cov_threshold=args.cov, edit_threshold=args.edit_threshold - # ) - - # genome_ctx: CountingContext = CountingContext(aggregate_writer, global_filter) - - # record_ctx = RecordCountingContext( - # feature_writer, aggregate_writer, global_filter, args.progress - # ) - # logging.info("Fetching records...") - # for i, record in enumerate(records): - # with open(args.sites) as sv_handle: - # sv_reader = reader_factory(sv_handle) - # logging.info(f"Record {record.id} setup") - # record_ctx.set_record(record) - # logging.info(f"Start counting on record {record.id}") - # record_ctx.launch_counting(sv_reader) - # logging.info(f"Ended counting on record {record.id}") - - # record_ctx.aggregate_writer.write_rows_with_data( - # record.id, ["."], ".", ".", record_ctx.aggregate_counters - # ) - # genome_ctx.update_aggregate_counters(record_ctx.aggregate_counters) - # genome_ctx.total_counter.merge(record_ctx.total_counter) - - # total_counter_dict: defaultdict = defaultdict(None) - # total_counter_dict[record.id] = 
record_ctx.total_counter - - record_ctx.aggregate_writer.write_rows_with_data( - # record.id, ["."], ".", ".", total_counter_dict - # ) - - # genome_ctx.aggregate_writer.write_rows_with_data( - # ".", ["."], ".", ".", genome_ctx.aggregate_counters - # ) - - # # Finally, create a dummy defaultdict to use the same method for writing the genome total - total_counter_dict: defaultdict = defaultdict(None) - total_counter_dict["."] = genome_ctx.total_counter - - genome_ctx.aggregate_writer.write_rows_with_data(".", ["."], ".", ".", total_counter_dict) genome_filter: SiteFilter = SiteFilter( cov_threshold=args.cov, edit_threshold=args.edit_threshold ) genome_total_counter: MultiCounter = MultiCounter(genome_filter) - genome_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + + genome_longest_isoform_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + lambda: MultiCounter(genome_filter) + ) + genome_all_isoforms_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( + lambda: MultiCounter(genome_filter) + ) + genome_chimaera_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( lambda: MultiCounter(genome_filter) ) @@ -633,17 +752,15 @@ def run_job(record: SeqRecord) -> dict[str,Any]: record_data_list: list[dict[str, Any]] = pool.map(run_job, records) # Sort record results in lexicographical order - record_data_list = sorted(record_data_list, key=lambda x: x["record_id"]) + record_data_list = natsorted(record_data_list, key=lambda x: x["record_id"]) with ( open(feature_output_filename, "a") as feature_output_handle, open(aggregate_output_filename, "a") as aggregate_output_handle, ): for record_data in record_data_list: - logging.info( - f"Record {record_data['record_id']} · Merging temporary output files..." 
- ) + logging.info(f"Record {record_data['record_id']} · Merging temporary output files...") with open(record_data["tmp_feature_output_file"]) as tmp_output_handle: feature_output_handle.write(tmp_output_handle.read()) @@ -655,13 +771,22 @@ def run_job(record: SeqRecord) -> dict[str,Any]: # Update the genome's aggregate counters from the record data aggregate counters for record_aggregate_type, record_aggregate_counter in record_data[ - "aggregate_counters" + "longest_isoform_aggregate_counters" ].items(): - genome_aggregate_counter: MultiCounter = genome_aggregate_counters[ + genome_aggregate_counter: MultiCounter = genome_longest_isoform_aggregate_counters[ record_aggregate_type ] genome_aggregate_counter.merge(record_aggregate_counter) + for record_aggregate_type, record_aggregate_counter in record_data["chimaera_aggregate_counters"].items(): + genome_aggregate_counter: MultiCounter = genome_chimaera_aggregate_counters[record_aggregate_type] + genome_aggregate_counter.merge(record_aggregate_counter) + + merge_aggregation_counter_dicts( + genome_all_isoforms_aggregate_counters, + record_data["all_isoforms_aggregate_counters"], + ) + # Update the genome's total counter from the record data total counter genome_total_counter.merge(record_data["total_counter"]) @@ -670,13 +795,23 @@ def run_job(record: SeqRecord) -> dict[str,Any]: aggregate_writer: AggregateFileWriter = AggregateFileWriter(aggregate_output_handle) # Write genomic counts - aggregate_writer.write_rows_with_data(".", ["."], ".", ".", genome_aggregate_counters) + aggregate_writer.write_rows_with_data( + ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters + ) + aggregate_writer.write_rows_with_data( + ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters + ) + aggregate_writer.write_rows_with_data( + ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters + ) # Write the genomic total. 
A dummy dict needs to be created to use the `write_rows_with_data` method genomic_total_counter_dict: defaultdict[str, MultiCounter] = defaultdict( lambda: MultiCounter(genome_filter) ) genomic_total_counter_dict["."] = genome_total_counter - aggregate_writer.write_rows_with_data(".", ["."], ".", ".", genomic_total_counter_dict) + aggregate_writer.write_rows_with_data( + ".", ["."], ".", ".", "all_sites", genomic_total_counter_dict + ) logging.info("Program finished") diff --git a/bin/quickstats.py b/bin/quickstats.py new file mode 100644 index 0000000..4a8de83 --- /dev/null +++ b/bin/quickstats.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python + +import argparse +import numpy as np +from utils import SiteVariantData +from SiteFilter import SiteFilter +from typing import Optional +from site_variant_readers import ( + RNAVariantReader, + Reditools2Reader, + Reditools3Reader, + Jacusa2Reader +) + +def parse_cli_input() -> argparse.Namespace: + """Parse command line input""" + + parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Site edits counter") + parser.add_argument( + "--input", "-i", + type=str, + required=True, + help="File of RNA variants per site" + ) + parser.add_argument( + "--format", "-f", + type=str, + choices=["reditools2", "reditools3", "jacusa2"], + required=True, + help="Format of the input file" + ) + + return parser.parse_args() + +if __name__ == "__main__": + args: argparse.Namespace = parse_cli_input() + filter: SiteFilter = SiteFilter(1, 0) + + with open(args.input) as input_handle: + reader: RNAVariantReader + match args.format: + case "reditools2": + reader = Reditools2Reader(input_handle) + case "reditools3": + reader = Reditools3Reader(input_handle) + case "jacusa2": + reader = Jacusa2Reader(input_handle) + case _: + raise Exception(f'Unimplemented format "{args.format}"') + + record_reads: np.typing.NDArray = np.zeros(5, dtype=np.int64) + record_sites: int = 0 + genome_reads: np.typing.NDArray = np.zeros(5, dtype=np.int64) + 
genome_sites: int = 0 + + svdata: Optional[SiteVariantData] = reader.read() + + if svdata: + current_record_id: str = svdata.seqid + np.copyto(record_reads, svdata.frequencies) + record_sites += int(np.any(svdata.frequencies > 0)) + + svdata = reader.read() + + while svdata: + if svdata.seqid != current_record_id: + genome_reads += record_reads + genome_sites += record_sites + print(f"Record {current_record_id}: covered sites {record_sites}\ttotal reads {record_reads.sum()}\tfrequencies {record_reads}") + record_reads[:] = 0 + record_sites = 0 + current_record_id = svdata.seqid + + record_reads += svdata.frequencies + record_sites += int(np.any(svdata.frequencies > 0)) + svdata = reader.read() + + genome_reads += record_reads + genome_sites += record_sites + + assert genome_sites >= record_sites + assert genome_reads.sum() >= record_reads.sum() + assert genome_reads.sum() >= genome_sites + + print(f"Record {current_record_id}: covered sites {record_sites}\ttotal reads {record_reads.sum()}\tfrequencies {record_reads}") + + print(f"Genome: covered sites {genome_sites}\ttotal reads {genome_reads.sum()}\tfrequencies {genome_reads}") diff --git a/bin/quickstats.sh b/bin/quickstats.sh new file mode 100644 index 0000000..b54978a --- /dev/null +++ b/bin/quickstats.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env sh + +input="" +format="" + +while getopts "i:f:" opt; do + case $opt in + i) + input="$OPTARG" + ;; + f) + format="$OPTARG" + ;; + esac +done + +if [ "$format" == "reditools" ]; then + grep -P ".+\t\d+\tA" $input | cut -f1,7 | tr -d "[]" | sed "s/, /\t/g" | awk '{sum[$1] += $4} END {for (key in sum) print "A->G", key, sum[key]}' + grep -P ".+\t\d+\tC" $input | cut -f1,7 | tr -d "[]" | sed "s/, /\t/g" | awk '{sum[$1] += $5} END {for (key in sum) print "C->T", key, sum[key]}' +elif [ "$format" == "jacusa2" ]; then + grep -P "\tA$" $input | cut -f1,7 | sed "s/,/\t/g" | awk '{sum[$1] += $4} END {for (key in sum) print "A->G", key, sum[key]}' + grep -P "\tC$" $input | cut -f1,7 | 
sed "s/,/\t/g" | awk '{sum[$1] += $5} END {for (key in sum) print "C->T", key, sum[key]}' +fi diff --git a/bin/utils.py b/bin/utils.py index 27ec66b..2d62869 100755 --- a/bin/utils.py +++ b/bin/utils.py @@ -1,11 +1,11 @@ -import numpy as np -from numpy.typing import NDArray -from dataclasses import dataclass from Bio.SeqFeature import SeqFeature, SimpleLocation, CompoundLocation -import itertools +from dataclasses import dataclass +from numpy.typing import NDArray from collections import deque -from typing import Optional from functools import reduce +from typing import Optional +import numpy as np +import itertools import logging logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ class SiteVariantData: strand: int coverage: int mean_quality: float - frequencies: NDArray[np.int32] + frequencies: NDArray[np.int64] score: float def overlaps(self: SimpleLocation, location: SimpleLocation) -> bool: @@ -73,7 +73,7 @@ def location_union(locations: list[SimpleLocation|CompoundLocation]) -> SimpleLo if len(locations) == 1: return locations[0] - comp_locations: CompoundLocation = reduce(lambda x, y: x + y, locations) + comp_locations: SimpleLocation|CompoundLocation = reduce(lambda x, y: x + y, locations) comp_locations.parts.sort(key=lambda part: (part.start, part.end), reverse=True) original_range = (comp_locations.parts[-1].start, max(map(lambda part: part.end, comp_locations.parts)))