From a51af320da046966981dd290947a00542625eec8 Mon Sep 17 00:00:00 2001 From: jiwen90 <70122688+jiwen90@users.noreply.github.com> Date: Fri, 27 Jan 2023 02:24:55 -0500 Subject: [PATCH] refactor for 10x speedup --- README.md | 4 +- setup.py | 2 +- tcri/metrics/_metrics.py | 20 +++--- tcri/plotting/_plotting.py | 4 +- tcri/preprocessing/_joint_distribution.py | 86 +++++++++++++++-------- 5 files changed, 73 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index fab3160..fa49a26 100644 --- a/README.md +++ b/README.md @@ -23,11 +23,13 @@ import tcri sample_column="sample" condition_column="source" +clonotype_column = "IR_VDJ_1_junction_aa" phenotype_column="genevector" tcri.pp.joint_distribution(adata, sample_column=sample_column, - condition_column=condition_column, + condition_column=condition_column, + clonotype_column=clonotype_column, phenotype_column=phenotype_column) ``` diff --git a/setup.py b/setup.py index bc8aa62..7279a9b 100644 --- a/setup.py +++ b/setup.py @@ -5,5 +5,5 @@ version='0.0.1', description='Information theoretic metrics for single cell RNA and TCR sequencing.', packages=find_packages(include=['tcri','tcri.metrics','tcri.preprocessing','tcri.plotting']), - install_requires=["scipy","numpy","notebook","sklearn","pandas","scanpy","tqdm","seaborn","matplotlib","pysankey"], + install_requires=["scipy","numpy","notebook","sklearn","pandas","scanpy","tqdm","seaborn","matplotlib","pysankeybeta"], ) diff --git a/tcri/metrics/_metrics.py b/tcri/metrics/_metrics.py index 4a843fc..b3e1786 100644 --- a/tcri/metrics/_metrics.py +++ b/tcri/metrics/_metrics.py @@ -9,7 +9,7 @@ def calc_entropy(P, log_units = 2): def clonotypic_entropy(adata, phenotype): clonotype_dist = collections.defaultdict(list) - for ct, clonotype in zip(adata.obs[adata.uns["phenotype_column"]], adata.obs["IR_VDJ_1_junction_aa"]): + for ct, clonotype in zip(adata.obs[adata.uns["phenotype_column"]], adata.obs[adata.uns["clonotype_column"]]): if ct == phenotype: prob = 1.0 else: @@ -55,18 +55,18 @@ def phenotypic_entropy(adata, min_clone_size=3): def phenotypic_flux(adata, from_this="Pre", to_that="Post", min_clone_size=3): precounts = collections.defaultdict(lambda : collections.defaultdict(int)) postcounts = collections.defaultdict(lambda : collections.defaultdict(int)) - for clone in tqdm.tqdm(list(set(adata.obs[adata.uns["clonotype_column"]]))): + phenotype_column = adata.uns["phenotype_column"] + clonotypes = adata.obs.groupby(adata.obs[adata.uns["clonotype_column"]]) + for clone, sub in tqdm.tqdm(clonotypes): if str(clone) == "nan": continue - sub = adata[adata.obs[adata.uns["clonotype_column"]]==clone].copy() - if len(sub.obs.index) < min_clone_size: continue - pre = sub[sub.obs[adata.uns["condition_column"]] == from_this] - post = sub[sub.obs[adata.uns["condition_column"]] == to_that] - for ph_pre in set(pre.obs["genevector"]): - for ph_post in set(post.obs["genevector"]): - precount = len(pre[pre.obs["genevector"]==ph_pre].obs.index) - postcount = len(post[post.obs["genevector"]==ph_post].obs.index) + if len(sub.index) < min_clone_size: continue + pre = sub[sub[adata.uns["condition_column"]] == from_this].value_counts(phenotype_column) + post = sub[sub[adata.uns["condition_column"]] == to_that].value_counts(phenotype_column) + for ph_pre, precount in pre.items(): + for ph_post, postcount in post.items(): precounts[ph_pre][ph_post] += precount postcounts[ph_pre][ph_post] += postcount + table = dict() table["Pre"] = [] table["Post"] = [] diff --git a/tcri/plotting/_plotting.py b/tcri/plotting/_plotting.py index 2fc69a6..d43d1cc 100644 --- a/tcri/plotting/_plotting.py +++ b/tcri/plotting/_plotting.py @@ -10,7 +10,9 @@ from ..metrics._metrics import phenotypic_entropy as pentropy from ..metrics._metrics import phenotypic_flux as flux -def tcr_umap(adata, reduction="umap", top_n=10, filename="tcr_plot.png", seq_column="IR_VDJ_1_junction_aa"): +def tcr_umap(adata, reduction="umap", top_n=10, filename="tcr_plot.png", seq_column=None): + if not seq_column: + seq_column = adata.uns["clonotype_column"] df = adata.obs dft = df[df[seq_column].notnull()] plt.figure(figsize = (10, 8)) diff --git a/tcri/preprocessing/_joint_distribution.py b/tcri/preprocessing/_joint_distribution.py index e23a260..79e9ae9 100644 --- a/tcri/preprocessing/_joint_distribution.py +++ b/tcri/preprocessing/_joint_distribution.py @@ -2,9 +2,10 @@ import numpy as np import pandas import sys -import tqdm +from tqdm import tqdm import numpy import collections +import multiprocessing def dicts2collated_arrays(dictsA, dictsB = []): if type(dictsA) is dict: dictsA = [dictsA] @@ -19,10 +20,44 @@ def dicts2collated_arrays(dictsA, dictsB = []): else: return arrayA, all_clonotypes +def initializer(clonotype_column, phenotype_column, min_clone_size, phenotypes): + global g_clonotype_column, g_phenotype_column, g_min_clone_size, g_phenotypes + g_clonotype_column = clonotype_column + g_phenotype_column = phenotype_column + g_min_clone_size = min_clone_size + g_phenotypes = phenotypes + +def process_parallel(ta): + global g_clonotype_column, g_phenotype_column, g_min_clone_size, g_phenotypes + sample, tadata = ta + return process(sample, tadata, g_clonotype_column, g_phenotype_column, g_min_clone_size, g_phenotypes) + +def process(sample, tadata, clonotype_column, phenotype_column, min_clone_size, phenotypes): + num_cells = 0 + joint_d = [] + clones = [] + for clonotype in tqdm(list(set(tadata[clonotype_column].tolist())), ncols=50): + sub = tadata[tadata[clonotype_column]==clonotype] + if len(sub.index) < min_clone_size: continue + dist = [] + for phenotype in phenotypes: + ph = phenotype+" Pseudo-probability" + if ph in sub.columns: + count = sum(sub[ph].tolist()) + else: + count = np.count_nonzero(sub[phenotype_column] == phenotype) + dist.append(count) + num_cells += count + joint_d.append(dist) + clones.append(clonotype) + joint_d = numpy.array(joint_d)# / num_cells + _joint_d = {clone: dist for clone, dist in zip(clones, joint_d)} + return sample, _joint_d + def joint_distribution(adata, condition_column="condition", sample_column="sample", - ir_column="IR_VDJ_1_junction_aa", phenotype_column="phenotype", min_clone_size=1): - if ir_column not in adata.obs: - raise ValueError("{} not found in adata.obs.keys".format(ir_column)) + clonotype_column="IR_VDJ_1_junction_aa", phenotype_column="phenotype", min_clone_size=1, cores=1): + if clonotype_column not in adata.obs: + raise ValueError("{} not found in adata.obs.keys".format(clonotype_column)) if phenotype_column not in adata.obs: raise ValueError("{} not found in adata.obs.keys".format(phenotype)) if condition_column not in adata.obs: @@ -30,7 +65,7 @@ def joint_distribution(adata, condition_column="condition", sample_column="sampl if sample_column not in adata.obs: raise ValueError("{} not found in adata.obs.keys".format(sample_column)) phenotypes = list(set(adata.obs[phenotype_column])) - clonotypes = list(set(adata.obs[ir_column])) + clonotypes = list(set(adata.obs[clonotype_column])) samples = list(set(adata.obs[sample_column])) conditions = list(set(adata.obs[condition_column])) @@ -39,33 +74,24 @@ def joint_distribution(adata, condition_column="condition", sample_column="sampl adata.uns["sample_order"] = samples adata.uns["condition_order"] = conditions - adata.uns["phenotype_column"] = phenotype_column - adata.uns["clonotype_column"] = ir_column + adata.uns["clonotype_column"] = clonotype_column adata.uns["condition_column"] = condition_column adata.uns["sample_column"] = sample_column joint_ds = collections.defaultdict(lambda : collections.defaultdict(lambda : collections.defaultdict(dict))) - for condition in conditions: - cadata = adata[adata.obs[condition_column]==condition] - for sample in list(set(cadata.obs[sample_column])): - tadata = cadata[cadata.obs[sample_column]==sample] - num_cells = 0 - joint_d = [] - clones = [] - for clonotype in tqdm.tqdm(list(set(tadata.obs[ir_column].tolist())), ncols=50): - sub = tadata[tadata.obs[ir_column]==clonotype] - if len(sub.obs.index.tolist()) < min_clone_size: continue - dist = [] - for phenotype in adata.uns["phenotype_order"]: - count = sum(sub.obs[phenotype+" Pseudo-probability"].tolist()) - dist.append(count) - num_cells += count - joint_d.append(dist) - clones.append(clonotype) - joint_d = numpy.array(joint_d)# / num_cells - _joint_d = dict() - for clone, dist in zip(clones,joint_d): - _joint_d[clone] = dist - joint_ds[condition][sample] = _joint_d - adata.uns["joint_probability_distribution"] = joint_ds \ No newline at end of file + cadatas = adata.obs.groupby(condition_column) + for condition, cadata in cadatas: + tadatas = cadata.groupby(sample_column) + + if cores > 1: + with multiprocessing.Pool( + processes=cores, initializer=initializer, initargs=(clonotype_column, phenotype_column, min_clone_size, phenotypes) + ) as pool: + results = list(tqdm(pool.imap(process_parallel, tadatas), total=tadatas.ngroups)) + joint_ds[condition] = {sample: _joint_d for sample, _joint_d in results} + + else: + joint_ds[condition] = {sample : process(sample, tadata, clonotype_column, phenotype_column, min_clone_size, phenotypes)[1] for sample, tadata in tadatas} + adata.uns["joint_probability_distribution"] = joint_ds + return joint_ds