# plotting/compile_data.py -- compiles per-environment Orb results into tidy
# pandas DataFrames consumed by the plotting routines in plot.py.
from glob import glob
from os.path import join, basename, relpath, sep

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

# Optional dependencies are guarded so the module stays importable (and
# unit-testable) in environments where they are not installed.
try:
    import yaml  # noqa: F401 -- only needed by callers that load the settings file
except ImportError:  # pragma: no cover
    yaml = None
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover -- fall back to a no-op progress wrapper
    def tqdm(iterable, **_kwargs):
        return iterable


def get_environments(fp_orb_basedir: str, settings) -> list[str]:
    """Iterate the base dir of Orb to find environments for which Orb has calculated metrics.

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    settings : dict (parsed from YAML)
        General settings for plotting Orb graphs. Here, we need
        a) the list of environments to skip ('skip_environments')
        b) a potential pre-defined order of the found environments
           ('labels' -> 'environments').

    Returns
    -------
    list[str]
        Environment names, i.e. sub-directories in Orb's base dir, ordered by
        their occurrence in the settings; unknown environments sort last.
    """
    # collect environments from orb directory
    environments = [basename(fp)
                    for fp in glob(join(fp_orb_basedir, '*'))
                    if basename(fp) not in settings['skip_environments']]
    # sort environments according to YAML occurrence; the key list is built
    # once here instead of once per comparison inside the sort key
    label_order = list(settings['labels']['environments'].keys())
    environments.sort(key=lambda env: label_order.index(env) if env in label_order else float('inf'))
    return environments


def getdata_recovery(fp_orb_basedir: str, settings, verbose=True) -> pd.DataFrame:
    """Collect data from Orb for contig recovery analysis.

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    settings : dict (parsed from YAML)
        General settings for plotting Orb graphs. Here, we need
        a) the environments to skip
        b) contig class labels/classes ('contig_classes')
        c) pretty assembler labels.
    verbose : bool
        Report progress on sys.stderr.

    Returns
    -------
    pd.DataFrame
        Performance parameters per assembler (rows) in the different
        environments, plus 'environment' and 'recovery_rank' columns.
    """
    environments = get_environments(fp_orb_basedir, settings)

    # label lists are loop-invariant: compute them once
    all_labels = [c['label'] for c in settings['contig_classes'].values()]
    good_labels = [c['label'] for c in settings['contig_classes'].values() if c['class'] == 'good']

    data = []
    for environment in tqdm(environments, disable=not verbose, desc='Compiling data for contig recovery plot'):
        # load orb data
        orb = pd.read_csv(join(fp_orb_basedir, environment, 'mergedresults', '%s_all_scores.tsv' % environment), sep="\t", index_col=0)
        # rename metrics to names used in publication
        orb = orb.rename(index={k: c['label'] for k, c in settings['contig_classes'].items()})
        # add information about missed blocks
        trueBlocks = pd.read_csv(join(fp_orb_basedir, environment, 'mergedataframessummaries', '%s_all_contigs.tsv' % environment), sep="\t", index_col=0).loc['count', ['Blocks']].sum()
        orb.loc['missed blocks', :] = trueBlocks - orb.loc[good_labels, :].sum()
        # select metrics for "recovery analysis"
        orb = orb.loc[all_labels, :]
        # sort assembler by amount of good contigs
        orb = orb[orb.loc[good_labels].sum().sort_values(ascending=True).index]
        # use pretty label for assembler
        orb = orb.rename(columns=settings['labels']['assemblers'])
        # transform axis: rows=assembler, cols=metrics+metadata
        orb = orb.T
        orb['environment'] = environment
        orb['recovery_rank'] = list(reversed(range(1, orb.shape[0] + 1)))
        data.append(orb)
    return pd.concat(data)


def get_recovered_contigs(fp_orb_basedir: str, settings, verbose=True):
    """Load per environment and assembler the best-hit contig mappings that
    fall into a "good" contig category.

    Returns
    -------
    dict
        recovered_contigs[environment][assembler] -> pd.DataFrame holding the
        best minimap2 hit per contig and per block, restricted to class
        'good', with an extra 'gene_name' column derived from the block id.
    """
    # minimap2 mapping column names; trailing "1".."6" are unnamed extras
    mapping_col_names = [
        "Query sequence name",
        "Query sequence length",
        "Query start coordinate",
        "Query end coordinate",
        "same strand",
        "block_id",  # "Target sequence name",
        "Target sequence length",
        "Target start coordinate on the original strand",
        "Target end coordinate on the original strand",
        "Number of matching bases in the mapping",
        "Number bases, including gaps, in the mapping",
        "Mapping quality", "1", "2", "3", "4", "5", "6"]

    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    recovered_contigs = dict()
    for environment in tqdm(environments, disable=not verbose, desc='Compiling data for gene recovery plot'):
        recovered_contigs[environment] = dict()
        for fp_category in glob(join(fp_orb_basedir, environment, 'categorizecontigs', '*_contigs_categorised.tsv')):
            assembler = basename(fp_category).split('_contigs_categorised.tsv')[0]

            contig_classes = pd.read_csv(fp_category, sep="\t", index_col=0)
            contig_classes.index = list(map(str, contig_classes.index))

            contig_mappings = pd.read_csv(
                join(fp_orb_basedir, environment, 'minimap2', '%s_mapping.tsv' % assembler), sep="\t", header=None, names=mapping_col_names, dtype={'Query sequence name': str}
            ).sort_values(by=['Mapping quality', 'Number of matching bases in the mapping'], ascending=[False, False]  # sort hits by mapping quality AND number of involved matching bases
            ).groupby(['Query sequence name']).head(1  # for every contig: pick only best hit
            ).groupby('block_id').head(1).merge(  # for every block: pick only best hit
                contig_classes, left_on=['Query sequence name'], right_index=True, how='left')  # merge Timo's contig categorization
            # add a binary classification: good and missed ...
            # (category list hoisted out of the per-row lambda)
            good_categories = [col for col, c in settings['contig_classes'].items() if c['class'] == 'good']
            contig_mappings['class'] = contig_mappings['category'].apply(lambda x: 'good' if x in good_categories else 'missed')
            # ... and keep only those in class "good"
            contig_mappings = contig_mappings[contig_mappings['class'] == 'good']

            # derive gene name from block_id
            contig_mappings['gene_name'] = contig_mappings['block_id'].apply(lambda x: x.split('_block')[0])

            recovered_contigs[environment][assembler] = contig_mappings

    return recovered_contigs
def getdata_gene_recovery(fp_orb_basedir: str, settings, full_gene_table=False, verbose=True) -> pd.DataFrame:
    """Compute, per environment, how many genes each assembler recovered.

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    settings : dict (parsed from YAML)
        General settings for plotting Orb graphs (skip list, labels, classes).
    full_gene_table : bool
        If True, return the raw boolean gene x assembler presence tables per
        environment instead of the aggregated counts.
    verbose : bool
        Report progress on sys.stderr.

    Returns
    -------
    pd.DataFrame
        Rows: assemblers plus 'core'/'shared'/'total'; columns: environments;
        values: number of recovered genes.
        If full_gene_table is True: dict environment -> boolean DataFrame.
    """
    # forward verbosity instead of hard-coding verbose=True
    recovered_contigs = get_recovered_contigs(fp_orb_basedir, settings, verbose=verbose)

    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    # global pandas option; set once instead of once per environment
    pd.set_option('future.no_silent_downcasting', True)

    recovered_genes = []
    for environment in tqdm(environments, disable=not verbose, desc='Compute sets of unique/shared/core genes for gene recovery plot'):
        # boolean gene x assembler table: True where the assembler recovered the gene
        features = pd.concat(
            [pd.Series(
                index=list(v['gene_name'].unique()),
                data=True,
                name=k,
             ).rename_axis('gene_name')
             for k, v in recovered_contigs[environment].items()
             ], axis=1).replace(np.nan, False).astype(bool)
        if full_gene_table:
            # NOTE(review): the index holds gene names, so renaming it with the
            # assembler label map looks like a no-op -- columns= may have been
            # intended; kept as-is to preserve behavior. TODO confirm.
            recovered_genes.append((environment, features.rename(index=settings['labels']['assemblers'])))
            continue

        # drop idba_mt as it is too close to idba_tran; tolerate its absence
        features = features.drop(columns=['idba_mt'], errors='ignore')
        # number genes found by only one assembler
        upset = features[features.sum(axis=1) == 1].unstack().unstack().sum(axis=1)
        # number genes found by all assemblers
        upset.loc['core'] = features[features.sum(axis=1) == len(features.columns)].shape[0]
        # number genes found by two or more, but not all assemblers
        upset.loc['shared'] = features[(features.sum(axis=1) < len(features.columns)) & (features.sum(axis=1) > 1)].shape[0]
        # total number genes
        upset.loc['total'] = features.shape[0]
        upset.name = environment
        recovered_genes.append(upset)
    if full_gene_table:
        return {environment: feature_table for (environment, feature_table) in recovered_genes}
    recovered_genes = pd.concat(recovered_genes, axis=1)
    recovered_genes.index.name = 'assembler'

    # pretty assembler names
    recovered_genes = recovered_genes.rename(index=settings['labels']['assemblers'])

    return recovered_genes
def getdata_runtime_memory(fp_caviar_basedir: str, settings, verbose=True):
    """Parse /usr/bin/time -v logs into a tidy runtime/memory DataFrame.

    Combines user+system time into 'CPU time (seconds)' and merges the three
    oases stages (velveth, velvetg, oases) into a single 'oases' entry.

    Returns
    -------
    pd.DataFrame with columns ['environment', 'assembler', 'type', 'value'].
    """
    fields_to_collect = ['Maximum resident set size (kbytes): ', 'User time (seconds): ', 'System time (seconds): ']
    data = []
    for fp_time in tqdm(glob('%s/**/*_time_log.txt' % fp_caviar_basedir, recursive=True), disable=not verbose, desc='Compiling data for runtime/memory footprint plot'):
        # derive environment and assembler from filenames
        environment = relpath(fp_time, fp_caviar_basedir).split(sep)[0]
        assembler = (basename(fp_time)[len(environment)+1:-1*len('_time_log.txt')])

        # parse /usr/bin/time verbose output
        with open(fp_time, 'r') as f:
            for line in f.readlines():
                for field in fields_to_collect:
                    if field in line:
                        value = float(line.split(field)[-1].strip())
                        data.append((environment, assembler, field, value))
    data = pd.DataFrame(data, columns=['environment', 'assembler', 'type', 'value'])

    # combine user and system time
    cmb_time = (data[data['type'] == 'User time (seconds): '].set_index(['environment', 'assembler'])['value'] + data[data['type'] == 'System time (seconds): '].set_index(['environment', 'assembler'])['value']).reset_index()
    cmb_time['type'] = 'CPU time (seconds)'

    timemem = pd.concat([data, cmb_time])
    # convert_to_fasta is a pre-processing step for some assemblers (they do
    # not always accept fastQ.gz) and is not an assembler itself
    timemem = timemem[timemem['assembler'] != 'convert_to_fasta']

    # combine runtime for oases, i.e. velveth + velvetg + oases
    assembler = 'oases'
    oases_combined = []
    for field in fields_to_collect + ['CPU time (seconds)']:
        # times add up across stages; memory is the per-stage maximum
        fct_combine = np.sum
        if field == 'Maximum resident set size (kbytes): ':
            fct_combine = np.max
        combined = timemem.set_index(['environment', 'assembler', 'type']).loc[:, ['velveth', 'velvetg', assembler], field, :].reset_index().groupby('environment')['value'].apply(fct_combine).to_frame().reset_index()
        combined['type'] = field
        combined['assembler'] = assembler
        oases_combined.append(combined)

    timemem = pd.concat([
        timemem[~timemem['assembler'].isin(['velveth', 'velvetg', assembler])],  # result without parts of oases
        pd.concat(oases_combined)]  # combined oases results
    )
    # pretty assembler names
    timemem['assembler'] = timemem['assembler'].apply(lambda x: settings['labels']['assemblers'].get(x, x))

    # pretty environment names
    timemem['environment'] = timemem['environment'].apply(lambda x: settings['labels']['environments'].get(x, x))

    return timemem


def getdata_DEgenes(fp_orb_basedir: str, settings, verbose=True):
    """Build a confusion table (TP/FP/FN/TN) of differential-expression calls
    per environment and assembler, compared against the reference DESeq2 run.

    Returns
    -------
    (pd.DataFrame, dict)
        Confusion counts indexed by (environment, assembler, class) with a
        'rank' column, and the recovered_contigs dict used to compute them.
    """
    # forward verbosity instead of hard-coding verbose=True
    recovered_contigs = get_recovered_contigs(fp_orb_basedir, settings, verbose=verbose)

    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    # global pandas option; set once instead of once per assembler
    pd.set_option('future.no_silent_downcasting', True)

    # names for the (truth, prediction) combinations
    class_names = {(False, False): 'True Negative',
                   (False, True): 'False Positive',
                   (True, False): 'False Negative',
                   (True, True): 'True Positive'}

    confusion = []
    for environment in tqdm(environments, disable=not verbose, desc='Compiling data for DE gene plot'):
        fp_truth = join(fp_orb_basedir, environment, 'refdeseq2', '%s_DESeq2_full_table.tsv' % environment)
        truth = pd.read_csv(fp_truth, sep="\t", index_col=0)
        truth.index = list(map(str, truth.index))

        # flag all genes which are DE
        truth = truth['padj'].apply(lambda x: x < 0.05).rename('truth')

        for fp_category in sorted(glob(join(fp_orb_basedir, environment, 'deseq2', '*_DESeq2_full_table.tsv'))):
            assembler = basename(fp_category).split('_DESeq2_full_table.tsv')[0]
            prediction = pd.read_csv(fp_category, sep="\t", index_col=0)
            prediction.index = list(map(str, prediction.index))

            # only keep contigs that have been identified as being statistically significant
            # flag all remaining contigs as DE prediction
            prediction = prediction['padj'].apply(lambda x: x < 0.05).rename('prediction')

            # combine assembled contigs with DE truth and prediction and count occurrences
            conv = recovered_contigs[environment][assembler].merge(
                truth, left_on='gene_name', right_index=True, how='outer').merge(
                prediction, left_on='Query sequence name', right_index=True, how='outer').groupby('gene_name').head(1).fillna(False).groupby(
                ['truth', 'prediction']).size()

            # re-structure confusion data
            conv = conv.rename('num_genes').to_frame()
            # give more speaking names; the index entries are (truth, prediction) tuples
            conv['class'] = [class_names.get(idx, idx) for idx in conv.index]
            # add in environment + assembler info; keep the RAW environment
            # name (the previous lookup in the *assembler* label table was a
            # no-op, and downstream plot code indexes this frame by raw name)
            conv['environment'] = environment
            conv['assembler'] = settings['labels']['assemblers'].get(assembler, assembler)
            conv = conv.reset_index()
            confusion.append(conv)

    # re-structure into one dataframe
    confusion = pd.concat(confusion, axis=0).set_index(['environment', 'assembler', 'class']).sort_index()
    confusion['rank'] = np.nan

    # add rank information to return dataframe
    for environment in confusion.index.levels[0]:
        order = list(reversed(
            pd.pivot_table(data=confusion.loc[environment, :], index='assembler', columns='class', values='num_genes', aggfunc="sum").sort_values(
                by=['True Positive', 'False Positive', 'False Negative'],
                ascending=[False, True, True]).index))
        # NOTE(review): the repeat count assumes every assembler reports the
        # same set of classes -- TODO confirm
        confusion.loc[confusion.loc[environment, order, :].index, 'rank'] = [
            rank
            for rank in list(reversed(range(1, len(order) + 1)))
            for i in range(confusion.reset_index()['class'].unique().shape[0])]

    return confusion, recovered_contigs
recovered_contigs + +def getdata_DEorthogroups(fp_orb_basedir:str, fp_marbel_basedir:str, fp_ogtruth_basedir:str, settings, verbose=True): + recovered_contigs = get_recovered_contigs(fp_orb_basedir, settings, verbose=True) + + # which environments to plot and in which order + environments = get_environments(fp_orb_basedir, settings) + + confusion = [] + for environment in tqdm(environments, disable=not verbose, desc='Compiling data for DE orthogroup plot'): + # obtain orthogroup information about genes + genes = pd.read_csv(join(fp_marbel_basedir, '%s_microbiome' % environment, "summary", "gene_summary.csv"), sep=",").set_index('gene_name') + + fp_truth = join(fp_ogtruth_basedir, environment, '%s_DESeq2_full_table.tsv' % environment) + truth = pd.read_csv(fp_truth, sep="\t", index_col=0) + truth.index = list(map(str, truth.index)) + + # flag all genes which are DE + truth = truth['padj'].apply(lambda x: x < 0.05).rename('truth') + + for fp_assembler in sorted(glob(join(fp_ogtruth_basedir, environment, '%s_*_DESeq2_full_table.tsv' % environment))): + assembler = basename(fp_assembler).split('_DESeq2_full_table.tsv')[0].split('%s_' % environment)[-1] + prediction = pd.read_csv(fp_assembler, sep="\t", index_col=0) + prediction.index = list(map(str, prediction.index)) + + prediction = prediction['padj'].apply(lambda x: x < 0.05).rename('prediction') + + pd.set_option('future.no_silent_downcasting', True) + conv = recovered_contigs[environment][assembler].merge( + genes[['orthogroup']], left_on='gene_name', right_index=True, how='left').merge( + truth, left_on='orthogroup', right_index=True, how='outer').merge( + prediction, left_on='orthogroup', right_index=True, how='outer').groupby( + 'orthogroup').head(1).fillna(False).groupby( + ['truth', 'prediction']).size() + + # re-structure convolution data + conv = conv.rename('num_genes').to_frame() + # give more speaking names + conv['class'] = list(map(lambda row: {(False, False): 'True Negative', + (False, True): 'False 
Positive', + (True, False): 'False Negative', + (True, True): 'True Positive'}.get(row[0], row[0]), conv.iterrows())) + # add in environment + assembler info + conv['environment'] = settings['labels']['assemblers'].get(environment, environment) + conv['assembler'] = settings['labels']['assemblers'].get(assembler, assembler) + conv = conv.reset_index() + confusion.append(conv) + + # re-structure into one dataframe + confusion = pd.concat(confusion, axis=0).set_index(['environment', 'assembler', 'class']).sort_index() + confusion['rank'] = np.nan + + # add rank information to return dataframe + for environment in confusion.index.levels[0]: + order = list(reversed( + pd.pivot_table(data=confusion.loc[environment, :], index='assembler', columns='class', values='num_genes', aggfunc="sum").sort_values( + by= ['True Positive', 'False Positive', 'False Negative'], + ascending=[False, True, True]).index)) + confusion.loc[confusion.loc[environment, order, :].index, 'rank'] = [ + rank + for rank in list(reversed(range(1, len(order) + 1))) + for i in range(confusion.reset_index()['class'].unique().shape[0])] + + return confusion, recovered_contigs + + +def getdata_DEvennOrtho(fp_orb_basedir:str, fp_ogtruth_basedir:str, fp_marbel_basedir:str, settings, verbose=True) -> pd.DataFrame: + # which environments to plot and in which order + environments = get_environments(fp_orb_basedir, settings) + + data = dict() + for environment in tqdm(environments, disable=not verbose, desc='Compiling data for DE Venn diagram'): + fp_truth = join(fp_orb_basedir, environment, 'refdeseq2', '%s_DESeq2_full_table.tsv' % environment) + truth = pd.read_csv(fp_truth, sep="\t", index_col=0) + truth.index = list(map(str, truth.index)) + + # obtain orthogroup information about genes + genes = pd.read_csv(join(fp_marbel_basedir, '%s_microbiome' % environment, "summary", "gene_summary.csv"), sep=",").set_index('gene_name') + # add a column that reports if the gene is part of a single or multi gene 
orthogroup + genes = genes.merge(genes.groupby('orthogroup').size().apply(lambda x: 'single_gene_OG' if x == 1 else 'multi_gene_OG').rename('OGsize'), + left_on='orthogroup', right_index=True, how='left') + truth = truth.merge(genes[['orthogroup', 'OGsize']], + left_index=True, right_index=True, how='left') + + # obtain DE truth if genes are collapsed into orthogroups + fp_truth_OG = join(fp_ogtruth_basedir, environment, '%s_DESeq2_full_table.tsv' % environment) + truthOG = pd.read_csv(fp_truth_OG, sep="\t", index_col=0) + truthOG.index = list(map(str, truthOG.index)) + truth = truth.merge(truthOG, left_on='orthogroup', right_index=True, how='left', suffixes=('_gene', '_orthogroup')) + + # flag elements as DE + for typ in ['gene', 'orthogroup']: + truth['DE%s' % typ] = truth['padj_%s' % typ] < 0.05 + + data[environment] = truth + + return data + + +def getdata_rnaquast(fp_orb_basedir:str, fp_quast_basedir:str, settings, verbose=True) -> pd.DataFrame: + # which environments to plot and in which order + environments = get_environments(fp_orb_basedir, settings) + + data = [] + for environment in tqdm(environments, disable=not verbose, desc='Compile RNAquast data'): + quast = pd.read_csv(join(fp_quast_basedir, environment, "short_report.tsv"), sep="\t", index_col=0) + quast.index.name = 'metric' + quast.columns = list(map(lambda x: settings['labels']['assemblers'].get(x.split('_contigs')[0], x), quast.columns)) + quast = quast.stack().reset_index().rename(columns={'level_1': 'assembler', 0: 'score'}) + quast['environment'] = environment + data.append(quast) + return pd.concat(data).set_index(['environment', 'assembler', 'metric']) diff --git a/plotting/plot.py b/plotting/plot.py new file mode 100644 index 0000000..0fe4bd8 --- /dev/null +++ b/plotting/plot.py @@ -0,0 +1,503 @@ +from glob import glob +from os.path import join, basename +import yaml +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +import matplotlib.colors as mcolors +from 
matplotlib import patches as mpatches +from matplotlib.lines import Line2D +from matplotlib_venn import venn2 +import colorsys +import seaborn as sns +from compile_data import get_environments, getdata_recovery, getdata_gene_recovery, getdata_runtime_memory, getdata_DEgenes, getdata_DEvennOrtho, getdata_DEorthogroups, get_recovered_contigs, getdata_rnaquast +from tqdm import tqdm +from scipy.cluster.hierarchy import linkage, leaves_list, dendrogram +from scipy.spatial.distance import pdist, squareform +from skbio.stats.distance import DistanceMatrix + + +def plot_recovery(fp_orb_basedir, settings, num_columns:int=3, verbose=True): + """Plots contig recovery. + + Parameters + ---------- + fp_orb_basedir : str + The filepath to Orb's base data dir + settings : yaml-dict + General settings for plotting Orb graphs. Here, we need + a) the order of the environments + b) pretty labels for environments and assembler + c) category information about contigs + num_columns : int + Maximal number of panels in a row. + verbose : boolean + Report progress on sys.stderr + + Returns + ------- + plt figure of the multi-pabel plot. 
+ """ + # which environments to plot and in which order + environments = get_environments(fp_orb_basedir, settings) + # load data + data_recovery = getdata_recovery(fp_orb_basedir, settings, verbose) + + fig, axes = plt.subplots( + int(np.ceil(len(environments) / num_columns)), + num_columns * 2, figsize=(2 * num_columns * 4, np.ceil(len(environments) / num_columns) * 5), + gridspec_kw={"wspace": 0.31, "hspace": 0.3}) + + for i, environment in tqdm(enumerate(environments), disable=not verbose, desc='Drawing panels for contig recovery plot'): + orb = data_recovery[data_recovery['environment'] == environment] + + # bad contigs + ax_bad = axes[i // num_columns, (i % num_columns) * 2] + ax_bad.invert_xaxis() + orb.loc[:, [c['label'] for _, c in settings['contig_classes'].items() if c['class'] == 'bad']].plot(kind='barh', stacked=True, ax=ax_bad, color={c['label']: c['color'] for _, c in settings['contig_classes'].items()}) + ax_bad.set_xlabel("number contigs") + ax_top_bad = ax_bad.twiny() + ax_top_bad.xaxis.set_label_position('top') + ax_top_bad.set_xticks([]) + ax_top_bad.set_xlabel("weak") + ax_bad.xaxis.set_label_coords(1, -0.08) + ax_bad.text(-0.5, 1.05, chr(97+i), transform=ax_bad.transAxes, fontsize=16, fontweight='bold',) + + # good contigs + ax_good = axes[i // num_columns, (i % num_columns) * 2 + 1] + orb.loc[:, [c['label'] for _, c in settings['contig_classes'].items() if c['class'] != 'bad']].plot(kind='barh', stacked=True, ax=ax_good, color={c['label']: c['color'] for _, c in settings['contig_classes'].items()}) + ax_good.set_yticks([]) + ax_good.set_xlabel("robust") + ax_good.xaxis.set_label_position('top') + + # concat right (=good) axis directly adjacent to left (=bad) axis + ax_good.set_position([ + ax_bad.get_position().x1, + ax_good.get_position().y0, + ax_good.get_position().width, + ax_good.get_position().height + ]) + ax_bad.set_title(settings['labels']['environments'].get(environment, environment), loc='right', horizontalalignment='center') + + # 
one legend for all panels + if i+1 == len(environments): + handles = ax_good.get_legend_handles_labels()[0] + list(reversed(ax_bad.get_legend_handles_labels()[0])) + labels = ax_good.get_legend_handles_labels()[1] + list(reversed(ax_bad.get_legend_handles_labels()[1])) + ax_good.legend(handles, labels, ncol=8, bbox_to_anchor=(-0.1, -0.20)) + else: + ax_good.legend().remove() + ax_bad.legend().remove() + + return fig + + +# # above function should be called as following: +# # 1) load plotting settings +# with open("/homes/sjanssen/Git/jlab/orb/plotting/style.yaml", "r") as f: +# settings = yaml.safe_load(f) +# # 2) then call plot_recovery with the filepath to Orb's base dir and the loaded settings +# _ = plot_recovery('/vol/jlab/tlin/all_project/nf_results/orb', settings) + +def _color_desaturate(color, saturation=0.75): + h, s, v = colorsys.rgb_to_hsv(*mcolors.to_rgb(color)) + return colorsys.hsv_to_rgb(h, s, v*saturation) + +def plot_gene_recovery(fp_orb_basedir, settings, num_columns:int=3, verbose=True): + # which environments to plot and in which order + environments = get_environments(fp_orb_basedir, settings) + # load data + recovered_genes = getdata_gene_recovery(fp_orb_basedir, settings) + + fig, axes = plt.subplots( + int(np.ceil(len(environments) / num_columns)), + num_columns, figsize=(num_columns * 5, np.ceil(len(environments) / num_columns) * 3), + gridspec_kw={"wspace": 0.6, "hspace": 0.4} + ) + + palette = {'core': 'gold', 'shared': 'cyan', 'total': 'lightgray'} + assemblers = [x for x in recovered_genes.index if x not in palette.keys()] + palette.update({assembler: sns.color_palette()[0] for assembler in assemblers}) + + for i, environment in tqdm(enumerate(environments), disable=not verbose, desc='Drawing panels for gene recovery plot'): + ax = axes[i // num_columns, i % num_columns] + + order = list(recovered_genes.loc[assemblers, environment].sort_values(ascending=False).index) + ['core', 'shared'] + 
sns.barplot(data=recovered_genes[environment].to_frame().reset_index(), orient='h', y='assembler', hue='assembler', x=environment, ax=ax, palette=palette, order=order) + ax.axvline(x=recovered_genes.loc['core', environment], color=palette['core']) + ax.set_ylabel("") + ax.set_xlabel("number recovered genes") + ax.set_title(settings['labels']['environments'].get(environment, environment)) + ax.set_xscale('log') + + if i+1 == len(environments): + ax.legend(handles=[ + mpatches.Patch(color=palette[assemblers[0]], label='exclusively recovered'), + mpatches.Patch(color=_color_desaturate(palette['core'], 0.75), label='recovered by all'), + mpatches.Patch(color=_color_desaturate(palette['shared'], 0.75), label='recovered by some')], + bbox_to_anchor=(-0.1, -0.25), ncols=3) + + # panel labels + ax.text(-0.43, 1.05, chr(97+i), transform=ax.transAxes, fontsize=16, fontweight='bold',) + + return fig + + +def plotTimeMemory(fp_caviar_basedir:str, settings, verbose=True): + timemem = getdata_runtime_memory(fp_caviar_basedir, settings) + + fig, axes = plt.subplots(1, 2, figsize=(10, 4), gridspec_kw={"wspace": 0.5}) + + for ax, (pType, factor, label, title) in zip(axes, [ + ('CPU time (seconds)', 3600, 'CPU hours', 'Runtime'), + ('Maximum resident set size (kbytes): ', 1024**2, 'RAM in GB', 'Memory footprint')]): + plotdata = timemem[timemem['type'] == pType].copy() + plotdata['unit'] = plotdata['value'] / factor + plotdata = plotdata.sort_values(by='unit') + sns.boxplot(data=plotdata, x='unit', y='assembler', orient='h', ax=ax, color='lightgray') + sns.stripplot(data=plotdata, x='unit', y='assembler', orient='h', ax=ax, hue='environment', hue_order=settings['labels']['environments']) + ax.set_xlabel(label) + ax.set_ylabel("") + ax.set_title(title) + if ax != axes[-1]: + ax.legend().remove() + else: + ax.axvline(x=64, linestyle='-.', color='gray', zorder=-1, label="64 GB laptop") + ax.legend(bbox_to_anchor=(1.1, -0.15), ncols=7) + + # panel labels + for i, ax in enumerate(axes): 
def get_rank_shifts(ranksA: pd.Series, ranksB: pd.Series, mergeAssembler={'idba': ['IDBA-MT', 'IDBA-tran']}):
    """Given two rankings for the same set of assemblers, compute the difference in rank positions.

    Note: optionally treat different assemblers as identical, e.g. IDBA-MT and
    IDBA-tran. The default dict is only read, never mutated.

    Parameters
    ----------
    ranksA : pd.Series
        Reference ranking (assembler -> rank).
    ranksB : pd.Series
        Other ranking to compare against the reference.
    mergeAssembler : dict
        merged label -> list of assembler labels to treat as one.

    Returns
    -------
    pd.DataFrame indexed by assembler with columns
    'rank_reference', 'rank_other', 'shift', 'label', 'color'.
    """
    def _merge_assembler(ranks, mergeAssembler):
        # map different assemblers to the same label, e.g. IDBA-MT and
        # IDBA-tran, because we don't see differences and want to keep them as
        # one; use rename() instead of assigning .name so the CALLER's Series
        # is not mutated as a side effect
        ranks = ranks.rename('old_ranks').to_frame()
        ranks['merged'] = list(map(lambda x: {label: mergelabel for mergelabel, assemblers in mergeAssembler.items() for label in assemblers}.get(x, x), ranks.index))
        return ranks

    def _rerank_nomergedassemblers(ranks):
        # assign new ranks according to the given ones, but assign the same
        # rank to identical (merged) assembler names
        rank = 0
        currAss = None
        newranks = []
        ranks = ranks.sort_values('old_ranks')
        for assembler in ranks['merged'].values:
            if currAss != assembler:
                rank += 1
                currAss = assembler
            newranks.append(rank)
        ranks['merged_ranks'] = newranks
        return ranks['merged_ranks']

    cmp = pd.concat([_rerank_nomergedassemblers(_merge_assembler(ranksA, mergeAssembler)).rename('rank_reference'),
                     _rerank_nomergedassemblers(_merge_assembler(ranksB, mergeAssembler)).rename('rank_other')], axis=1)
    # should an assembler be missing in one of the ranks, assign worst+1 rank to it
    for col in cmp.columns:
        cmp[col] = cmp[col].fillna(cmp[col].max() + 1)
    cmp['shift'] = cmp.iloc[:, 0] - cmp.iloc[:, 1]

    def _create_label(row):
        # arrow up = improved rank, arrow down = worsened rank
        if row['shift'] > 0:
            return '%s ⬆+%i' % (row.name, row['shift'])
        elif row['shift'] < 0:
            return '%s ⬇%i' % (row.name, row['shift'])
        else:
            return '%s %i' % (row.name, row['shift'])
    cmp['label'] = cmp.apply(_create_label, axis=1)

    def _create_color(row):
        if row['shift'] > 0:
            return 'darkgreen'
        elif row['shift'] < 0:
            return 'darkorange'
        else:
            return 'black'
    cmp['color'] = cmp.apply(_create_color, axis=1)

    return cmp


def plot_DEgenes(fp_orb_basedir, settings, forOrthogroups=False, fp_marbel_basedir: str = None, fp_ogtruth_basedir: str = None, num_columns: int = 3, verbose=True):
    """Plot TP/FP/FN counts of differential-expression calls per assembler.

    If forOrthogroups is True, plot orthogroup-level instead of gene-level
    results (requires fp_marbel_basedir and fp_ogtruth_basedir).

    Returns
    -------
    plt figure of the multi-panel plot.
    """
    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    # load data
    DEfeatures, _ = getdata_DEgenes(fp_orb_basedir, settings, verbose)
    if forOrthogroups is False:
        rank_data = getdata_recovery(fp_orb_basedir, settings, verbose)
    else:
        # use the gene-level DE ranks as the reference ranking
        rank_data = DEfeatures.loc[:, :, 'True Positive']['rank'].reset_index().rename(columns={'rank': 'recovery_rank'}).set_index('assembler').copy()
        # forward verbosity instead of using the default
        DEfeatures, _ = getdata_DEorthogroups(fp_orb_basedir, fp_marbel_basedir, fp_ogtruth_basedir, settings, verbose)

    fig, axes = plt.subplots(
        int(np.ceil(len(environments) / num_columns)),
        num_columns, figsize=(num_columns * 7, np.ceil(len(environments) / num_columns) * 4),
        gridspec_kw={"hspace": 0.31, "wspace": 0.5})
    BARWIDTH = 0.8

    palette = {'True Positive': '#238cc3',
               'False Positive': '#5d5e60',
               'False Negative': '#c06364'}
    for i, environment in tqdm(enumerate(environments), disable=not verbose, desc='Drawing panels for DE plot'):
        ax = axes[i // num_columns, i % num_columns]

        # assemblers ordered by their DE rank, worst first
        order = list(DEfeatures.loc[environment, :].reset_index().groupby('assembler').head(1).set_index('assembler')['rank'].sort_values(ascending=False).index)
        for y, assembler in enumerate(order):
            # TP bar to the right of zero, FP stacked on top of TP, FN to the left
            cls = 'True Positive'
            pos_TP = DEfeatures.loc[environment, assembler, cls]['num_genes']
            ax.add_patch(plt.Rectangle((0, y), pos_TP, BARWIDTH, facecolor=palette[cls], edgecolor='white', linewidth=1, label=cls if y == 0 else None))

            cls = 'False Positive'
            pos_FP = DEfeatures.loc[environment, assembler, cls]['num_genes']
            ax.add_patch(plt.Rectangle((pos_TP, y), pos_FP, BARWIDTH, facecolor=palette[cls], edgecolor='white', linewidth=1, label=cls if y == 0 else None))

            cls = 'False Negative'
            pos_FN = DEfeatures.loc[environment, assembler, cls]['num_genes']
            ax.add_patch(plt.Rectangle((0, y), -1 * pos_FN, BARWIDTH, facecolor=palette[cls], edgecolor='white', linewidth=1, label=cls if y == 0 else None))
        ax.set_ylim((-1 * (1 - BARWIDTH), len(order)))

        # annotate rank changes relative to the reference ranking
        ranks = get_rank_shifts(rank_data[rank_data['environment'] == environment]['recovery_rank'],
                                pd.Series(index=order, data=reversed(range(1, len(order) + 1))))
        ax.set_yticks(list(map(lambda x: x + BARWIDTH/2, range(len(order)))), ranks.loc[order, 'label'].values)
        for tick_label, color in zip(ax.get_yticklabels(), ranks.loc[order, 'color'].values):
            tick_label.set_color(color)

        ax.set_title(settings['labels']['environments'].get(environment, environment))
        ax.set_xlabel('number genes')

        # total number of truly DE genes (identical for every assembler)
        num_positives = DEfeatures.loc[environment, :].reset_index().set_index('truth').loc[True].groupby('assembler')['num_genes'].sum().iloc[0]
        ax.axvline(x=num_positives, color=palette['True Positive'], label='Positive')
        maxX = (DEfeatures.loc[environment, :, 'True Positive']['num_genes'] + DEfeatures.loc[environment, :, 'False Positive']['num_genes']).max()
        # hand-picked outliers whose FP bar would dominate the x-range:
        # exclude them from the scale and print their FP count instead
        for (ex_forOrtho, ex_env, ex_ass) in [(False, 'seawater', 'Trinity'),
                                              (False, 'freshwater', 'Trinity'),
                                              (False, 'healthy_gut', 'dbg'),
                                              ]:
            if (ex_forOrtho == forOrthogroups) and (environment == ex_env):
                pdata = DEfeatures.loc[environment, [ass for ass in DEfeatures.loc[environment, :].index.levels[0] if ass != ex_ass], :]
                maxX = (pdata.loc[environment, :, 'True Positive']['num_genes'] + pdata.loc[environment, :, 'False Positive']['num_genes']).max()

                ax.text(max(1, num_positives, maxX), list(order).index(ex_ass) + 0.35, '%i' % DEfeatures.loc[environment, ex_ass, 'False Positive']['num_genes'],
                        verticalalignment='center', horizontalalignment='right', color='white')

        # panel labels
        ax.text(-0.43, 1.05, chr(97+i), transform=ax.transAxes, fontsize=16, fontweight='bold',)

        ax.set_xlim((
            -1.1 * max(1, DEfeatures.loc[environment, :, 'False Negative']['num_genes'].max()),
            1.1 * max(1, num_positives, maxX)))

        if i+1 == len(environments):
            ax.legend(handles=[mpatches.Patch(color=palette[cls], label=cls) for cls in palette.keys()] +
                      [Line2D([0], [0], color=palette['True Positive'], lw=2, label='Positive')],
                      # title='Category',
                      bbox_to_anchor=(-0.4, -0.25),
                      ncols=4)
    return fig
        # x-range: FN to the left, TP+FP (or the clipped maxX) to the right, with a 10% margin
        ax.set_xlim((
            -1.1 * max(1, DEfeatures.loc[environment, :, 'False Negative']['num_genes'].max()),
            1.1 * max(1, num_positives, maxX)))

        # a single shared legend, drawn below the last panel only
        if i+1 == len(environments):
            ax.legend(handles=[mpatches.Patch(color=palette[cls], label=cls) for cls in palette.keys()] + \
                      [Line2D([0], [0], color=palette['True Positive'], lw=2, label='Positive')],
                      #title='Category',
                      bbox_to_anchor=(-0.4, -0.25),
                      ncols=4)
    return fig


def plot_DEvennOrtho(fp_orb_basedir:str, fp_ogtruth_basedir:str, fp_marbel_basedir:str, settings, num_columns:int=3, verbose=True):
    """Draw a 3 x len(environments) grid of Venn diagrams comparing DE orthogroups
    against orthogroups that contain DE genes.

    Row 0 uses all orthogroups, row 1 only multi-gene orthogroups,
    row 2 only single-gene orthogroups.

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    fp_ogtruth_basedir : str
        Orthogroup ground-truth base dir.
    fp_marbel_basedir : str
        Marbel base dir.
    settings : yaml-dict
        General settings for plotting Orb graphs.
    num_columns : int
        Stride used when re-ordering environments column-by-column.
    verbose : boolean
        Report progress on sys.stderr.

    Returns
    -------
    The assembled matplotlib figure.
    """
    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    # get data
    truth = getdata_DEvennOrtho(fp_orb_basedir, fp_ogtruth_basedir, fp_marbel_basedir, settings)

    fig, axes = plt.subplots(
        3,
        len(environments), figsize=(len(environments) * 3, len(environments)/2 * 3),
        #gridspec_kw={"wspace": 0.6, "hspace": 0.4}
    )
    palette = {'DE orthogroups': 'blue',
               'DE genes contained in orthogroups': 'orange'}
    labels = ["", ""]

    # re-order so environments appear column-by-column (transpose of row-major order)
    reordered_environments = []
    for start in range(num_columns):
        reordered_environments.extend(environments[start::num_columns])

    for i, environment in tqdm(enumerate(reordered_environments), disable=not verbose, desc='Drawing panels for DE Venn diagrams'):
        # row 0: all orthogroups
        ax = axes[0][i]
        venn2([set(truth[environment][truth[environment]['DEorthogroup']]['orthogroup'].values),
               set(truth[environment][truth[environment]['DEgene']]['orthogroup'].values)],
              labels,
              ax=ax, set_colors=(palette['DE orthogroups'], palette['DE genes contained in orthogroups']))
        ax.set_title(settings['labels']['environments'].get(environment, environment))
        ax.set_ylabel("all orthogroups")

        # row 1: restrict both sets to multi-gene orthogroups
        ax = axes[1][i]
        venn2([set(truth[environment][truth[environment]['DEorthogroup'] & (truth[environment]['OGsize'] == 'multi_gene_OG')]['orthogroup'].values),
               set(truth[environment][truth[environment]['DEgene'] & (truth[environment]['OGsize'] == 'multi_gene_OG')]['orthogroup'].values)],
              labels,
              ax=ax, set_colors=(palette['DE orthogroups'], palette['DE genes contained in orthogroups']))
        ax.set_ylabel("only multi gene orthogroups")

        # row 2: restrict both sets to single-gene orthogroups
        ax = axes[2][i]
        venn2([set(truth[environment][truth[environment]['DEorthogroup'] & (truth[environment]['OGsize'] == 'single_gene_OG')]['orthogroup'].values),
               set(truth[environment][truth[environment]['DEgene'] & (truth[environment]['OGsize'] == 'single_gene_OG')]['orthogroup'].values)],
              labels,
              ax=ax, set_colors=(palette['DE orthogroups'], palette['DE genes contained in orthogroups']))
        ax.set_ylabel("only single gene orthogroups")

        # first column: paint spines white so the row labels stay visible without a frame
        if i == 0:
            for row in range(len(axes)):
                axes[row][i].axison = True
                for p in ['left', 'right', 'bottom', 'top']:
                    axes[row][i].spines[p].set_color('white')
    # one shared legend anchored relative to the bottom-left panel
    axes[2][0].legend(handles=[mpatches.Patch(label=grp, facecolor=color, alpha=0.4) for (grp, color) in palette.items()],
                      bbox_to_anchor=(4.8, -0.),
                      ncols=2)

    return fig


def plot_heatmap(fp_orb_basedir:str, settings, num_columns:int=3, verbose=True):
    """Plot per-environment heatmaps of pairwise Jaccard distances between the
    gene sets recovered by each assembler, with a dendrogram above each heatmap.

    Two pseudo-environments are spiked into the panel grid: 'colormap' (hosts
    the shared colour bar) and 'all six environments' (all environments pooled).

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    settings : yaml-dict
        General settings for plotting Orb graphs.
    num_columns : int
        Number of panel columns before the spike-in extends the grid.
    verbose : boolean
        Report progress on sys.stderr.

    Returns
    -------
    The assembled matplotlib figure.
    """
    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)
    # re-order environments such that two additional "environments"
    # are spiked in for the color map and the combined environment
    # NOTE(review): the num_cols parameter is unused — the closure reads the
    # outer num_columns instead; confirm this is intended
    def _spikein(environments, num_cols=3, spikeelements=['colormap', 'all six environments']):
        chunks = [environments[i:i+num_columns] for i in range(0, len(environments), num_columns)]
        reordered = []
        for i, spike in enumerate(spikeelements):
            reordered.extend(chunks[i])
            reordered.append(spike)
        for chunk in chunks[len(spikeelements):]:
            reordered.extend(chunk)
        return reordered
    ext_environments = _spikein(environments)
    # one extra grid column to accommodate the spiked-in pseudo panels
    num_columns += 1

    # load data
    recovered_contigs = get_recovered_contigs(fp_orb_basedir, settings, verbose)

    # two axes rows per panel row: a slim dendrogram row (ratio 1) above each heatmap row (ratio 5)
    fig, axes = plt.subplots(
        int(np.ceil(len(ext_environments) / num_columns)) * 2,
        num_columns,
        figsize=(num_columns * 5,
                 np.ceil((len(ext_environments) + 1) / num_columns) * 4),
        height_ratios=[1, 5] * (len(ext_environments) // num_columns),
        gridspec_kw={"hspace": 0.6, "wspace": 0.6}
    )

    for i, environment in tqdm(enumerate(ext_environments), disable=not verbose, desc='Compute heatmap'):
        col = i % num_columns
        row = i // num_columns
        ax_dendro = axes[row * 2 + 0, col]
        ax_heat = axes[row * 2 + 1, col]

        # panel labels
        ax_dendro.text(-0.33, 1.05, chr(97+i), transform=ax_dendro.transAxes, fontsize=16, fontweight='bold',)

        # the 'colormap' pseudo-panel only hosts the shared colour bar; shrink it and skip plotting
        if environment == 'colormap':
            ax_dendro.axis("off")
            ax_heat.set_position([
                ax_heat.get_position().x0,
                ax_heat.get_position().y0,
                ax_heat.get_position().width / 5,
                ax_heat.get_position().height
            ])
            ax_heat.set_title("Jaccard distance")
            continue
        pd.set_option('future.no_silent_downcasting', True)
        # build a boolean gene x assembler membership matrix
        # (True = the gene was recovered by that assembler)
        if environment == 'all six environments':
            features = pd.concat(
                [pd.concat(
                    [pd.Series(
                        index=list(v['gene_name'].unique()),
                        data=True,
                        name=k,
                    ).rename_axis('gene_name')
                     for k, v in recovered_contigs[env].items()
                     ], axis=1).replace(np.nan, False).astype(bool)
                 for env in environments])
        else:
            features = pd.concat(
                [pd.Series(
                    index=list(v['gene_name'].unique()),
                    data=True,
                    name=k,
                ).rename_axis('gene_name')
                 for k, v in recovered_contigs[environment].items()
                 ], axis=1).replace(np.nan, False).astype(bool)

        # pairwise Jaccard distances between assemblers, labelled with pretty names
        jaccard_distances = DistanceMatrix(squareform(pdist(features.T, metric='jaccard')), ids=[settings['labels']['assemblers'].get(ass, ass) for ass in features.columns]).to_data_frame()

        # cluster assemblers and use the dendrogram leaf order to sort the heatmap
        linkage_matrix = linkage(features.T, method='average', metric='jaccard')#, optimal_ordering=False)
        col_dendro = dendrogram(linkage_matrix, no_plot=True)
        col_order = col_dendro['leaves']
        # draw the colour bar only once, into the dedicated 'colormap' panel axes
        sns.heatmap(jaccard_distances.iloc[col_order, col_order], ax=ax_heat, cbar=row==0 and col==num_columns-2, vmin=0, vmax=1, cbar_ax=axes[1, len(axes[0])-1])

        dendrogram(linkage_matrix, ax=ax_dendro, color_threshold=0,
                   no_labels=True)
        ax_dendro.axis("off")
        ax_dendro.set_title(settings['labels']['environments'].get(environment, environment))
        # glue the dendrogram flush on top of its heatmap
        ax_dendro.set_position([
            ax_dendro.get_position().x0,
            ax_heat.get_position().y1,
            ax_dendro.get_position().width,
            ax_dendro.get_position().height
        ])

    return fig


def plot_rnaquast(fp_orb_basedir:str, fp_quast_basedir:str, settings, num_columns:int=3, verbose=True):
    """Draw back-to-back bar panels of rnaQUAST metrics per environment:
    'Misassemblies' (weak, growing left) mirrored against
    '95%-assembled isoforms' (robust, growing right) per assembler.

    Parameters
    ----------
    fp_orb_basedir : str
        The filepath to Orb's base data dir.
    fp_quast_basedir : str
        Base dir holding the rnaQUAST results.
    settings : yaml-dict
        General settings for plotting Orb graphs.
    num_columns : int
        Number of environment columns (each uses two axes: bad | good).
    verbose : boolean
        Report progress on sys.stderr.

    Returns
    -------
    The assembled matplotlib figure.
    """
    # which environments to plot and in which order
    environments = get_environments(fp_orb_basedir, settings)

    # load data
    quast = getdata_rnaquast(fp_orb_basedir, fp_quast_basedir, settings)
    data_recovery = getdata_recovery(fp_orb_basedir, settings, verbose)

    fig, axes = plt.subplots(
        int(np.ceil(len(environments) / num_columns)),
        num_columns * 2, figsize=(2 * num_columns * 4, np.ceil(len(environments) / num_columns) * 5),
        gridspec_kw={"wspace": 0.31, "hspace": 0.3})

    for i, environment in tqdm(enumerate(environments), disable=not verbose, desc='Draw RNAquast panels'):
        # assembler order follows the contig-recovery ranking
        order = list(data_recovery[data_recovery['environment'] == environment].sort_values(by='recovery_rank').index)

        # bad contigs
        ax_bad = axes[i // num_columns, (i % num_columns) * 2]
        sns.barplot(data=quast.loc[environment, :, 'Misassemblies'], x='score', y='assembler', ax=ax_bad, order=order,
                    color=settings['contig_classes']['multi_mapped_contigs_multi_og']['color'])
        # mirrored axis: misassemblies grow to the left
        ax_bad.invert_xaxis()
        ax_bad.set_ylabel("")
        ax_bad.set_xlabel("number contigs")
        ax_bad.set_title(settings['labels']['environments'].get(environment, environment), loc='right', horizontalalignment='center')
        # secondary (top) x-axis only carries the "weak" caption
        ax_top_bad = ax_bad.twiny()
        ax_top_bad.xaxis.set_label_position('top')
        ax_top_bad.set_xticks([])
        ax_top_bad.set_xlabel("weak")
        ax_bad.xaxis.set_label_coords(1, -0.08)
        # panel label
        ax_bad.text(-0.5, 1.05, chr(97+i), transform=ax_bad.transAxes, fontsize=16, fontweight='bold',)

        ax_good = axes[i // num_columns, (i % num_columns) * 2 + 1]
        sns.barplot(data=quast.loc[environment, :, '95%-assembled isoforms'], x='score', y='assembler', ax=ax_good, order=order,
                    color=settings['contig_classes']['mapped_contigs']['color'])
        ax_good.set_yticks([])
        ax_good.set_xlabel("robust")
        ax_good.xaxis.set_label_position('top')
        ax_good.set_ylabel("")

        # concat right (=good) axis directly adjacent to left (=bad) axis
        ax_good.set_position([
            ax_bad.get_position().x1,
            ax_good.get_position().y0,
            ax_good.get_position().width,
            ax_good.get_position().height])

        # one legend for all panels
        if i+1 == len(environments):
            ax_good.legend(handles=[
                mpatches.Patch(color=settings['contig_classes']['multi_mapped_contigs_multi_og']['color'], label='Misassemblies'),
                mpatches.Patch(color=settings['contig_classes']['mapped_contigs']['color'], label='95%-assembled isoforms')],
                ncol=2, bbox_to_anchor=(-1.8, -0.20))

    return fig
\ No newline at end of file
diff --git a/plotting/style.yaml b/plotting/style.yaml
new file mode 100644
index 0000000..07fde24
--- /dev/null
+++ b/plotting/style.yaml
@@ -0,0 +1,62 @@
# environments that are present as directories, but shall not be included into figures
skip_environments:
  - 'fixed_insert'

# contig categories: 'class' (good/neutral/bad) groups them in the plots,
# 'color' and 'label' style the corresponding bars/legends
contig_classes:
  mapped_contigs:
    class: good
    color: 'darkgreen'
    label: 'recovered'
  chimeric_mapped_contigs:
    class: good
    color: 'green'
    label: 'overlapping annotation'
  multi_mapped_contigs_single_og:
    class: good
    color: 'lightgreen'
    label: 'orthologous recovered'
  missed:
    class: neutral
    color: '#dddddd'
    label: 'missed blocks'
  unmapped_contigs:
    class: bad
    color: '#B84303'
    label: 'hallucinated'
  multi_mapped_contigs_multi_og:
    class: bad
    color: '#F00F18'
    label: 'chimera'
  single_mapped_contigs:
    class: bad
    color: '#FECB47'
    label: 'incomplete'
  length_filtered_contigs:
    class: bad
    color: '#FEEAA2'
    label: 'length filtered'

# rename ugly technical names with pretty tool / environment labels
labels:
  assemblers:
    dbg: 'dbg'
    # NOTE(review): 'ibda_tran' looks like a typo for 'idba_tran' — confirm
    # against the actual data directory names before relying on this key
    ibda_tran: 'IDBA-tran'
    idba_mt: 'IDBA-MT'
    # alias of idba_mt: both directory spellings map to the same pretty label
    idba_mt_fasta: 'IDBA-MT'
    megahit: 'MEGAHIT'
    oases: 'Oases'
    rnaspades: 'rnaSPAdes'
    # hyphen/underscore directory variants map to one label
    soap-denovo-trans: 'SOAPdenovo-Trans'
    soap_denovo_trans: 'SOAPdenovo-Trans'
    transabyss: 'Trans-ABySS'
    transabyss_transabyss: 'Trans-ABySS'
    trinity: 'Trinity'

  # pretty names for environment directories; key order also defines the
  # panel ordering (see get_environments in compile_data.py)
  environments:
    oak: 'oak'
    seawater: 'seawater'
    healthy_gut: 'healthy gut'
    moss: 'moss'
    freshwater: 'freshwater'
    diseased_gut: 'diseased gut'