Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ import tcri

sample_column="sample"
condition_column="source"
clonotype_column = "IR_VDJ_1_junction_aa"
phenotype_column="genevector"

tcri.pp.joint_distribution(adata,
sample_column=sample_column,
condition_column=condition_column,
condition_column=condition_column,
clonotype_column=clonotype_column,
phenotype_column=phenotype_column)
```

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
version='0.0.1',
description='Information theoretic metrics for single cell RNA and TCR sequencing.',
packages=find_packages(include=['tcri','tcri.metrics','tcri.preprocessing','tcri.plotting']),
install_requires=["scipy","numpy","notebook","sklearn","pandas","scanpy","tqdm","seaborn","matplotlib","pysankey"],
install_requires=["scipy","numpy","notebook","sklearn","pandas","scanpy","tqdm","seaborn","matplotlib","pysankeybeta"],
)
20 changes: 10 additions & 10 deletions tcri/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def calc_entropy(P, log_units = 2):

def clonotypic_entropy(adata, phenotype):
clonotype_dist = collections.defaultdict(list)
for ct, clonotype in zip(adata.obs[adata.uns["phenotype_column"]], adata.obs["IR_VDJ_1_junction_aa"]):
for ct, clonotype in zip(adata.obs[adata.uns["phenotype_column"]], adata.obs[adata.uns["clonotype_column"]]):
if ct == phenotype:
prob = 1.0
else:
Expand Down Expand Up @@ -55,18 +55,18 @@ def phenotypic_entropy(adata, min_clone_size=3):
def phenotypic_flux(adata, from_this="Pre", to_that="Post", min_clone_size=3):
precounts = collections.defaultdict(lambda : collections.defaultdict(int))
postcounts = collections.defaultdict(lambda : collections.defaultdict(int))
for clone in tqdm.tqdm(list(set(adata.obs[adata.uns["clonotype_column"]]))):
phenotype_column = adata.uns["phenotype_column"]
clonotypes = adata.obs.groupby(adata.obs[adata.uns["clonotype_column"]])
for clone, sub in tqdm.tqdm(clonotypes):
if str(clone) == "nan": continue
sub = adata[adata.obs[adata.uns["clonotype_column"]]==clone].copy()
if len(sub.obs.index) < min_clone_size: continue
pre = sub[sub.obs[adata.uns["condition_column"]] == from_this]
post = sub[sub.obs[adata.uns["condition_column"]] == to_that]
for ph_pre in set(pre.obs["genevector"]):
for ph_post in set(post.obs["genevector"]):
precount = len(pre[pre.obs["genevector"]==ph_pre].obs.index)
postcount = len(post[post.obs["genevector"]==ph_post].obs.index)
if len(sub.index) < min_clone_size: continue
pre = sub[sub[adata.uns["condition_column"]] == from_this].value_counts(phenotype_column)
post = sub[sub[adata.uns["condition_column"]] == to_that].value_counts(phenotype_column)
for ph_pre, precount in pre.items():
for ph_post, postcount in post.items():
precounts[ph_pre][ph_post] += precount
postcounts[ph_pre][ph_post] += postcount

table = dict()
table["Pre"] = []
table["Post"] = []
Expand Down
4 changes: 3 additions & 1 deletion tcri/plotting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from ..metrics._metrics import phenotypic_entropy as pentropy
from ..metrics._metrics import phenotypic_flux as flux

def tcr_umap(adata, reduction="umap", top_n=10, filename="tcr_plot.png", seq_column="IR_VDJ_1_junction_aa"):
def tcr_umap(adata, reduction="umap", top_n=10, filename="tcr_plot.png", seq_column=None):
if not seq_column:
seq_column = adata.uns["clonotype_column"]
df = adata.obs
dft = df[df[seq_column].notnull()]
plt.figure(figsize = (10, 8))
Expand Down
86 changes: 56 additions & 30 deletions tcri/preprocessing/_joint_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import numpy as np
import pandas
import sys
import tqdm
from tqdm import tqdm
import numpy
import collections
import multiprocessing

def dicts2collated_arrays(dictsA, dictsB = []):
if type(dictsA) is dict: dictsA = [dictsA]
Expand All @@ -19,18 +20,52 @@ def dicts2collated_arrays(dictsA, dictsB = []):
else:
return arrayA, all_clonotypes

def initializer(clonotype_column, phenotype_column, min_clone_size, phenotypes):
    """Store the shared settings as module-level globals.

    Passed as the ``initializer`` of the ``multiprocessing.Pool`` created in
    ``joint_distribution`` so that ``process_parallel`` can read the settings
    from module globals instead of receiving them with every work item.

    Args:
        clonotype_column: Name of the clonotype column.
        phenotype_column: Name of the phenotype column.
        min_clone_size: Minimum number of rows a clonotype needs to be kept.
        phenotypes: Phenotype labels, in output order.
    """
    globals().update(
        g_clonotype_column=clonotype_column,
        g_phenotype_column=phenotype_column,
        g_min_clone_size=min_clone_size,
        g_phenotypes=phenotypes,
    )

def process_parallel(ta):
    """Run ``process`` on one ``(sample, table)`` work item.

    The remaining arguments are read from the module globals that
    ``initializer`` sets, so only the work item itself is passed in.

    Args:
        ta: Two-element pair ``(sample, tadata)``.

    Returns:
        Whatever ``process`` returns for that sample.
    """
    # Reading module globals needs no ``global`` declaration.
    sample, tadata = ta
    return process(
        sample,
        tadata,
        clonotype_column=g_clonotype_column,
        phenotype_column=g_phenotype_column,
        min_clone_size=g_min_clone_size,
        phenotypes=g_phenotypes,
    )

def process(sample, tadata, clonotype_column, phenotype_column, min_clone_size, phenotypes):
    """Build the clonotype-by-phenotype count table for one sample.

    For every clonotype with at least ``min_clone_size`` rows, one value is
    computed per phenotype: the sum of the ``"<phenotype> Pseudo-probability"``
    column when the table has such a column, otherwise the number of rows
    whose ``phenotype_column`` equals that phenotype.

    Args:
        sample: Sample identifier; returned unchanged so callers can pair
            results with their inputs.
        tadata: Table of cells for this sample, indexable by column label
            (presumably a pandas DataFrame taken from ``adata.obs`` --
            confirm against the caller).
        clonotype_column: Name of the column holding clonotype labels.
        phenotype_column: Name of the column holding phenotype labels.
        min_clone_size: Clonotypes with fewer rows than this are skipped.
        phenotypes: Phenotype labels; fixes the order of each output vector.

    Returns:
        Tuple ``(sample, mapping)`` where ``mapping`` maps each retained
        clonotype to a 1-D numpy array with one entry per phenotype.

    Notes:
        Rows whose clonotype is missing (NaN) are never reported: grouping
        drops them, which matches the explicit NaN skip in
        ``phenotypic_flux``.
    """
    # The column layout is identical for every clonotype, so decide once per
    # phenotype whether to sum pseudo-probabilities or to count labels.
    sources = []
    for phenotype in phenotypes:
        weight_column = phenotype + " Pseudo-probability"
        if weight_column not in tadata.columns:
            weight_column = None
        sources.append((phenotype, weight_column))

    joint_d = []
    clones = []
    # One grouping pass instead of one full boolean scan per clonotype.
    # observed=True keeps only clonotypes actually present when the column is
    # categorical; sort=False avoids needlessly ordering the keys.
    groups = tadata.groupby(clonotype_column, sort=False, observed=True)
    for clonotype, sub in tqdm(groups, ncols=50):
        if len(sub.index) < min_clone_size:
            continue
        dist = []
        for phenotype, weight_column in sources:
            if weight_column is not None:
                count = sum(sub[weight_column].tolist())
            else:
                count = np.count_nonzero(sub[phenotype_column] == phenotype)
            dist.append(count)
        joint_d.append(dist)
        clones.append(clonotype)

    # Counts are left unnormalised; each dict value is one row of the array.
    joint_d = np.array(joint_d)
    return sample, dict(zip(clones, joint_d))

def joint_distribution(adata, condition_column="condition", sample_column="sample",
ir_column="IR_VDJ_1_junction_aa", phenotype_column="phenotype", min_clone_size=1):
if ir_column not in adata.obs:
raise ValueError("{} not found in adata.obs.keys".format(ir_column))
clonotype_column="IR_VDJ_1_junction_aa", phenotype_column="phenotype", min_clone_size=1, cores=1):
if clonotype_column not in adata.obs:
raise ValueError("{} not found in adata.obs.keys".format(clonotype_column))
if phenotype_column not in adata.obs:
raise ValueError("{} not found in adata.obs.keys".format(phenotype))
if condition_column not in adata.obs:
raise ValueError("{} not found in adata.obs.keys".format(condition_column))
if sample_column not in adata.obs:
raise ValueError("{} not found in adata.obs.keys".format(sample_column))
phenotypes = list(set(adata.obs[phenotype_column]))
clonotypes = list(set(adata.obs[ir_column]))
clonotypes = list(set(adata.obs[clonotype_column]))
samples = list(set(adata.obs[sample_column]))
conditions = list(set(adata.obs[condition_column]))

Expand All @@ -39,33 +74,24 @@ def joint_distribution(adata, condition_column="condition", sample_column="sampl
adata.uns["sample_order"] = samples
adata.uns["condition_order"] = conditions


adata.uns["phenotype_column"] = phenotype_column
adata.uns["clonotype_column"] = ir_column
adata.uns["clonotype_column"] = clonotype_column
adata.uns["condition_column"] = condition_column
adata.uns["sample_column"] = sample_column

joint_ds = collections.defaultdict(lambda : collections.defaultdict(lambda : collections.defaultdict(dict)))
for condition in conditions:
cadata = adata[adata.obs[condition_column]==condition]
for sample in list(set(cadata.obs[sample_column])):
tadata = cadata[cadata.obs[sample_column]==sample]
num_cells = 0
joint_d = []
clones = []
for clonotype in tqdm.tqdm(list(set(tadata.obs[ir_column].tolist())), ncols=50):
sub = tadata[tadata.obs[ir_column]==clonotype]
if len(sub.obs.index.tolist()) < min_clone_size: continue
dist = []
for phenotype in adata.uns["phenotype_order"]:
count = sum(sub.obs[phenotype+" Pseudo-probability"].tolist())
dist.append(count)
num_cells += count
joint_d.append(dist)
clones.append(clonotype)
joint_d = numpy.array(joint_d)# / num_cells
_joint_d = dict()
for clone, dist in zip(clones,joint_d):
_joint_d[clone] = dist
joint_ds[condition][sample] = _joint_d
adata.uns["joint_probability_distribution"] = joint_ds
cadatas = adata.obs.groupby(condition_column)
for condition, cadata in cadatas:
tadatas = cadata.groupby(sample_column)

if cores > 1:
with multiprocessing.Pool(
processes=cores, initializer=initializer, initargs=(clonotype_column, phenotype_column, min_clone_size, phenotypes)
) as pool:
results = list(tqdm(pool.imap(process_parallel, tadatas), total=tadatas.ngroups))
joint_ds[condition] = {sample: _joint_d for sample, _joint_d in results}

else:
joint_ds[condition] = {sample : process(sample, tadata, clonotype_column, phenotype_column, min_clone_size, phenotypes)[1] for sample, tadata in tadatas}
adata.uns["joint_probability_distribution"] = joint_ds
return joint_ds