diff --git a/singlem/summariser.py b/singlem/summariser.py
index fff7de01..baddc31b 100644
--- a/singlem/summariser.py
+++ b/singlem/summariser.py
@@ -3,9 +3,7 @@
 import extern
 from collections import OrderedDict
 import logging
-import pandas
 import Bio
-import pandas as pd
 import polars as pl
 import gzip
 
@@ -351,6 +349,26 @@ def write_collapsed_paired_with_unpaired_otu_table(**kwargs):
         overall_df = None
         ar = None
 
+        def archive_schema(fields):
+            schema = {}
+            for field in fields:
+                if field in {
+                    "read_names",
+                    "nucleotides_aligned",
+                    "read_unaligned_sequences",
+                    "equal_best_hit_taxonomies",
+                }:
+                    schema[field] = pl.Object
+                elif field == "num_hits":
+                    schema[field] = pl.Int64
+                elif field == "coverage":
+                    schema[field] = pl.Float64
+                elif field == "taxonomy_by_known?":
+                    schema[field] = pl.Boolean
+                else:
+                    schema[field] = pl.Utf8
+            return schema
+
         def read_archive_table(df, f, prev_ar):
             logging.debug("Reading archive table {} into RAM ..".format(a))
             ar = ArchiveOtuTable.read(f)
@@ -359,8 +377,9 @@ def read_archive_table(df, f, prev_ar):
                 # fields = ar.fields
                 # alignment_hmm_sha256s = ar.alignment_hmm_sha256s
                 # singlem_package_sha256s = ar.singlem_package_sha256s
-                df = pandas.DataFrame(ar.data)
-                df.columns = ar.fields
+                df = pl.DataFrame(
+                    ar.data, schema=archive_schema(ar.fields), orient="row"
+                )
             else:
                 if prev_ar.version != ar.version:
                     raise Exception("Version mismatch between archives")
@@ -370,9 +389,10 @@
                     raise Exception("Alignment HMM SHA256 mismatch between archives")
                 elif prev_ar.singlem_package_sha256s != ar.singlem_package_sha256s:
                     raise Exception("Singlem package SHA256 mismatch between archives")
-                df2 = pandas.DataFrame(ar.data)
-                df2.columns = prev_ar.fields
-                df = pd.concat([df, df2], ignore_index=True)
+                df2 = pl.DataFrame(
+                    ar.data, schema=archive_schema(prev_ar.fields), orient="row"
+                )
+                df = pl.concat([df, df2], how="vertical")
             return df, ar
 
         for a in archive_otu_tables:
@@ -395,54 +415,66 @@
 
         # Remove suffixes
         if set_sample_name is None:
-            def remove_suffix(s):
-                if s.endswith('_1'):
-                    return s[:-2]
-                else:
-                    return s
-            df['sample'] = df['sample'].apply(remove_suffix)
+            df = df.with_columns(
+                pl.col("sample").str.replace(r"_1$", "").alias("sample")
+            )
 
         # Ensure that there is now only exactly 1 sample name
-        if set_sample_name is None and len(df['sample'].unique()) != 1:
-            raise Exception("Multiple sample names found: {}".format(', '.join(df['sample'].unique())))
-        if len(df['taxonomy_by_known?'].unique()) != 1:
-            raise Exception("Multiple taxonomy_by_known found: {}".format(', '.join(df['taxonomy_by_known'].unique())))
-
-        def combine_rows(grouped1):
-            grouped = grouped1.reset_index()
-            max_row = grouped['num_hits'].idxmax()
+        if set_sample_name is None:
+            sample_names = df.select(pl.col("sample").unique()).to_series().to_list()
+            if len(sample_names) != 1:
+                raise Exception("Multiple sample names found: {}".format(', '.join(sample_names)))
+        taxonomy_by_known = df.select(pl.col("taxonomy_by_known?").unique()).to_series().to_list()
+        if len(taxonomy_by_known) != 1:
+            raise Exception("Multiple taxonomy_by_known found: {}".format(', '.join(map(str, taxonomy_by_known))))
+
+        df = df.sort(["sequence", "gene"])
+
+        collapsed_rows = []
+        for grouped in df.partition_by(
+            ["sequence", "gene"], maintain_order=True, as_dict=False
+        ):
+            num_hits = grouped.get_column("num_hits").to_list()
+            max_row = num_hits.index(max(num_hits))
             if set_sample_name:
                 sample = set_sample_name
             else:
-                sample = grouped.iloc[0]['sample']
-            tax_assignment_method = grouped.iloc[0]['taxonomy_assignment_method']
+                sample = grouped.get_column("sample")[0]
+            tax_assignment_method = grouped.get_column("taxonomy_assignment_method")[0]
             if tax_assignment_method == QUERY_BASED_ASSIGNMENT_METHOD:
-                equal_best_hit_taxonomies = grouped.iloc[0]['equal_best_hit_taxonomies']
+                equal_best_hit_taxonomies = grouped.get_column("equal_best_hit_taxonomies")[0]
             elif tax_assignment_method == DIAMOND_ASSIGNMENT_METHOD:
-                equal_best_hit_taxonomies = list(itertools.chain(*grouped['equal_best_hit_taxonomies']))
+                equal_best_hit_taxonomies = list(
+                    itertools.chain(*grouped.get_column("equal_best_hit_taxonomies").to_list())
+                )
             elif tax_assignment_method == None or tax_assignment_method == NO_ASSIGNMENT_METHOD:
                 equal_best_hit_taxonomies = None
             else:
                 raise Exception("Unexpected tax assignment method: {}".format(tax_assignment_method))
-            return pd.DataFrame({
-                'gene':[grouped.iloc[0]['gene']],
-                'sample':[sample],
-                'sequence':[grouped.iloc[0]['sequence']],
-                'num_hits':[sum(grouped['num_hits']),],
-                'coverage':[sum(grouped['coverage']),],
-                'taxonomy':[grouped.iloc[max_row]['taxonomy']],
-                'read_names':[list(itertools.chain(*grouped['read_names']))],
-                'nucleotides_aligned':[list(itertools.chain(*grouped['nucleotides_aligned']))],
-                'taxonomy_by_known?':[grouped.iloc[0]['taxonomy_by_known?']],
-                'read_unaligned_sequences':[list(itertools.chain(*grouped['read_unaligned_sequences']))],
-                'equal_best_hit_taxonomies':[equal_best_hit_taxonomies],
-                'taxonomy_assignment_method':[tax_assignment_method],
-            })
-        transformed = df.groupby(['sequence','gene'], as_index=False).apply(combine_rows)[ArchiveOtuTable.FIELDS]
-        logging.info("Collapsed {} total OTUs into {} output OTUs".format(len(df), len(transformed)))
+
+            collapsed_rows.append(
+                [
+                    grouped.get_column("gene")[0],
+                    sample,
+                    grouped.get_column("sequence")[0],
+                    sum(num_hits),
+                    sum(grouped.get_column("coverage").to_list()),
+                    grouped.get_column("taxonomy")[max_row],
+                    list(itertools.chain(*grouped.get_column("read_names").to_list())),
+                    list(itertools.chain(*grouped.get_column("nucleotides_aligned").to_list())),
+                    grouped.get_column("taxonomy_by_known?")[0],
+                    list(
+                        itertools.chain(*grouped.get_column("read_unaligned_sequences").to_list())
+                    ),
+                    equal_best_hit_taxonomies,
+                    tax_assignment_method,
+                ]
+            )
+
+        logging.info("Collapsed {} total OTUs into {} output OTUs".format(len(df), len(collapsed_rows)))
 
         logging.debug("Writing output table ..")
-        ar.data = transformed.values.tolist()
+        ar.data = collapsed_rows
         ar.write_to(output_table_io)
         # json.dump({"version": ar.version,
         #            "alignment_hmm_sha256s": ar.alignment_hmm_sha256s,
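
A minimal runnable sketch of the partition_by collapse pattern introduced above, on a toy frame with invented values (column names mirror a subset of the real archive OTU fields; the data and the simpler dtypes are illustrative only, not taken from the PR):

import itertools
import polars as pl

# Toy stand-in for the merged archive OTU table; values are made up.
df = pl.DataFrame({
    "gene": ["S1.1", "S1.1", "S1.2"],
    "sequence": ["AAA", "AAA", "TTT"],
    "num_hits": [2, 5, 1],
    "taxonomy": ["Root; d__Bacteria", "Root; d__Bacteria; p__Firmicutes", "Root"],
    "read_names": [["r1", "r2"], ["r3"], ["r4"]],
})

collapsed = []
for grouped in df.sort(["sequence", "gene"]).partition_by(
    ["sequence", "gene"], maintain_order=True, as_dict=False
):
    num_hits = grouped.get_column("num_hits").to_list()
    max_row = num_hits.index(max(num_hits))  # row with the most hits supplies the taxonomy
    collapsed.append([
        grouped.get_column("gene")[0],
        grouped.get_column("sequence")[0],
        sum(num_hits),  # hit counts are summed across the duplicate rows
        grouped.get_column("taxonomy")[max_row],
        # per-read lists are concatenated across the duplicate rows
        list(itertools.chain(*grouped.get_column("read_names").to_list())),
    ])

print(collapsed)
# [['S1.1', 'AAA', 7, 'Root; d__Bacteria; p__Firmicutes', ['r1', 'r2', 'r3']],
#  ['S1.2', 'TTT', 1, 'Root', ['r4']]]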