116 changes: 74 additions & 42 deletions singlem/summariser.py
@@ -3,9 +3,7 @@
 import extern
 from collections import OrderedDict
 import logging
-import pandas
-import Bio
 import pandas as pd
+import polars as pl
 import gzip

Expand Down Expand Up @@ -351,6 +349,26 @@ def write_collapsed_paired_with_unpaired_otu_table(**kwargs):
overall_df = None
ar = None

def archive_schema(fields):
schema = {}
for field in fields:
if field in {
"read_names",
"nucleotides_aligned",
"read_unaligned_sequences",
"equal_best_hit_taxonomies",
}:
schema[field] = pl.Object
elif field == "num_hits":
schema[field] = pl.Int64
elif field == "coverage":
schema[field] = pl.Float64
elif field == "taxonomy_by_known?":
schema[field] = pl.Boolean
else:
schema[field] = pl.Utf8
return schema

def read_archive_table(df, f, prev_ar):
logging.debug("Reading archive table {} into RAM ..".format(a))
ar = ArchiveOtuTable.read(f)
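The new archive_schema helper pins explicit polars dtypes so that list-valued archive fields (read_names, nucleotides_aligned, and so on) stay plain Python lists inside pl.Object columns instead of being converted to polars List columns. A minimal sketch (not part of the PR) of how it drives row-orientated construction, assuming archive_schema is in scope; the row values below are invented:

import polars as pl

fields = ["gene", "sample", "sequence", "num_hits", "coverage", "read_names"]
rows = [["S1.01", "soil", "ATGC", 2, 3.42, ["read_a", "read_b"]]]
df = pl.DataFrame(rows, schema=archive_schema(fields), orient="row")
print(df.dtypes)  # three string columns, then Int64, Float64, Object
print(df.get_column("read_names")[0])  # ['read_a', 'read_b'] (still a plain Python list)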
@@ -359,8 +377,9 @@ def read_archive_table(df, f, prev_ar):
             # fields = ar.fields
             # alignment_hmm_sha256s = ar.alignment_hmm_sha256s
             # singlem_package_sha256s = ar.singlem_package_sha256s
-            df = pandas.DataFrame(ar.data)
-            df.columns = ar.fields
+            df = pl.DataFrame(
+                ar.data, schema=archive_schema(ar.fields), orient="row"
+            )
         else:
             if prev_ar.version != ar.version:
                 raise Exception("Version mismatch between archives")
@@ -370,9 +389,10 @@ def read_archive_table(df, f, prev_ar):
                 raise Exception("Alignment HMM SHA256 mismatch between archives")
             elif prev_ar.singlem_package_sha256s != ar.singlem_package_sha256s:
                 raise Exception("Singlem package SHA256 mismatch between archives")
-            df2 = pandas.DataFrame(ar.data)
-            df2.columns = prev_ar.fields
-            df = pd.concat([df, df2], ignore_index=True)
+            df2 = pl.DataFrame(
+                ar.data, schema=archive_schema(prev_ar.fields), orient="row"
+            )
+            df = pl.concat([df, df2], how="vertical")
         return df, ar

     for a in archive_otu_tables:
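The body of the for a in archive_otu_tables: loop falls outside the diff context, but the (df, ar) threading suggests an accumulation pattern along these lines. A hedged sketch only; the archive paths and the plain open() call are assumptions, not shown in the diff:

overall_df = None
ar = None
for a in ["sample1.json", "sample2.json"]:  # hypothetical archive paths
    with open(a) as f:  # real code may open gzipped archives instead
        overall_df, ar = read_archive_table(overall_df, f, ar)
# overall_df now holds every archive's rows in one polars frame, and ar
# retains the metadata (version, fields, SHA256 digests) checked above.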
@@ -395,54 +415,66 @@ def read_archive_table(df, f, prev_ar):

     # Remove suffixes
     if set_sample_name is None:
-        def remove_suffix(s):
-            if s.endswith('_1'):
-                return s[:-2]
-            else:
-                return s
-        df['sample'] = df['sample'].apply(remove_suffix)
+        df = df.with_columns(
+            pl.col("sample").str.replace(r"_1$", "").alias("sample")
+        )
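The pandas remove_suffix helper stripped a literal trailing "_1"; the polars replacement anchors the same pattern as a regex. A quick check (sample names invented) that the two behave identically, including removing only one trailing occurrence:

import polars as pl

df = pl.DataFrame({"sample": ["soil_1", "soil_1_1", "soil_10", "water"]})
df = df.with_columns(pl.col("sample").str.replace(r"_1$", "").alias("sample"))
print(df.get_column("sample").to_list())  # ['soil', 'soil_1', 'soil_10', 'water']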

     # Ensure that there is now only exactly 1 sample name
-    if set_sample_name is None and len(df['sample'].unique()) != 1:
-        raise Exception("Multiple sample names found: {}".format(', '.join(df['sample'].unique())))
-    if len(df['taxonomy_by_known?'].unique()) != 1:
-        raise Exception("Multiple taxonomy_by_known found: {}".format(', '.join(df['taxonomy_by_known'].unique())))
-
-    def combine_rows(grouped1):
-        grouped = grouped1.reset_index()
-        max_row = grouped['num_hits'].idxmax()
+    if set_sample_name is None:
+        sample_names = df.select(pl.col("sample").unique()).to_series().to_list()
+        if len(sample_names) != 1:
+            raise Exception("Multiple sample names found: {}".format(', '.join(sample_names)))
+    taxonomy_by_known = df.select(pl.col("taxonomy_by_known?").unique()).to_series().to_list()
+    if len(taxonomy_by_known) != 1:
+        raise Exception("Multiple taxonomy_by_known found: {}".format(', '.join(map(str, taxonomy_by_known))))

df = df.sort(["sequence", "gene"])

collapsed_rows = []
for grouped in df.partition_by(
["sequence", "gene"], maintain_order=True, as_dict=False
):
num_hits = grouped.get_column("num_hits").to_list()
max_row = num_hits.index(max(num_hits))
         if set_sample_name:
             sample = set_sample_name
         else:
-            sample = grouped.iloc[0]['sample']
-        tax_assignment_method = grouped.iloc[0]['taxonomy_assignment_method']
+            sample = grouped.get_column("sample")[0]
+        tax_assignment_method = grouped.get_column("taxonomy_assignment_method")[0]
         if tax_assignment_method == QUERY_BASED_ASSIGNMENT_METHOD:
-            equal_best_hit_taxonomies = grouped.iloc[0]['equal_best_hit_taxonomies']
+            equal_best_hit_taxonomies = grouped.get_column("equal_best_hit_taxonomies")[0]
         elif tax_assignment_method == DIAMOND_ASSIGNMENT_METHOD:
-            equal_best_hit_taxonomies = list(itertools.chain(*grouped['equal_best_hit_taxonomies']))
+            equal_best_hit_taxonomies = list(
+                itertools.chain(*grouped.get_column("equal_best_hit_taxonomies").to_list())
+            )
         elif tax_assignment_method == None or tax_assignment_method == NO_ASSIGNMENT_METHOD:
             equal_best_hit_taxonomies = None
         else:
             raise Exception("Unexpected tax assignment method: {}".format(tax_assignment_method))
-        return pd.DataFrame({
-            'gene':[grouped.iloc[0]['gene']],
-            'sample':[sample],
-            'sequence':[grouped.iloc[0]['sequence']],
-            'num_hits':[sum(grouped['num_hits']),],
-            'coverage':[sum(grouped['coverage']),],
-            'taxonomy':[grouped.iloc[max_row]['taxonomy']],
-            'read_names':[list(itertools.chain(*grouped['read_names']))],
-            'nucleotides_aligned':[list(itertools.chain(*grouped['nucleotides_aligned']))],
-            'taxonomy_by_known?':[grouped.iloc[0]['taxonomy_by_known?']],
-            'read_unaligned_sequences':[list(itertools.chain(*grouped['read_unaligned_sequences']))],
-            'equal_best_hit_taxonomies':[equal_best_hit_taxonomies],
-            'taxonomy_assignment_method':[tax_assignment_method],
-        })
-    transformed = df.groupby(['sequence','gene'], as_index=False).apply(combine_rows)[ArchiveOtuTable.FIELDS]
-    logging.info("Collapsed {} total OTUs into {} output OTUs".format(len(df), len(transformed)))

+        collapsed_rows.append(
+            [
+                grouped.get_column("gene")[0],
+                sample,
+                grouped.get_column("sequence")[0],
+                sum(num_hits),
+                sum(grouped.get_column("coverage").to_list()),
+                grouped.get_column("taxonomy")[max_row],
+                list(itertools.chain(*grouped.get_column("read_names").to_list())),
+                list(itertools.chain(*grouped.get_column("nucleotides_aligned").to_list())),
+                grouped.get_column("taxonomy_by_known?")[0],
+                list(
+                    itertools.chain(*grouped.get_column("read_unaligned_sequences").to_list())
+                ),
+                equal_best_hit_taxonomies,
+                tax_assignment_method,
+            ]
+        )

+    logging.info("Collapsed {} total OTUs into {} output OTUs".format(len(df), len(collapsed_rows)))

logging.debug("Writing output table ..")
ar.data = transformed.values.tolist()
ar.data = collapsed_rows
ar.write_to(output_table_io)
# json.dump({"version": ar.version,
# "alignment_hmm_sha256s": ar.alignment_hmm_sha256s,