Skip to content

Commit e12c43b

Browse files
correct neighborhood masking and utils for adding regulons
1 parent 32724cd commit e12c43b

2 files changed

Lines changed: 45 additions & 5 deletions

File tree

src/netmap/masking/internal.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def create_pairwise_binary_mask(binary_matrix, gene_list):
7777

7878
gene_pairs_indices = list(itertools.combinations(range(num_genes), 2))
7979
for g1_idx, g2_idx in gene_pairs_indices:
80-
mask = binary_matrix[:, g1_idx] * binary_matrix[:, g2_idx]
80+
mask = np.multiply(binary_matrix[:, g1_idx] , binary_matrix[:, g2_idx])
8181
key_fwd = f"{gene_list[g1_idx]}_{gene_list[g2_idx]}"
8282
pairwise_mask_dict[key_fwd] = mask
8383
key_rev = f"{gene_list[g2_idx]}_{gene_list[g1_idx]}"
@@ -102,13 +102,14 @@ def dict_to_dataframe(mask_dict, column_order_list):
102102
pd.DataFrame: A DataFrame with masks as columns, in the specified order.
103103
"""
104104
# 1. Create a dictionary with only the ordered columns
105-
ordered_data = {col: mask_dict[col] for col in column_order_list if col in mask_dict}
105+
ordered_data = {col: np.asarray(mask_dict[col]).squeeze() for col in column_order_list if col in mask_dict}
106106

107107
# 2. Check if all specified columns were found
108108
if len(ordered_data) != len(column_order_list):
109109
missing_columns = set(column_order_list) - set(ordered_data.keys())
110110
print(f"Warning: The following columns were not found in the mask dictionary: {missing_columns}")
111111

112+
print(ordered_data)
112113
# 3. Create the DataFrame from the ordered dictionary
113114
df = pd.DataFrame(ordered_data)
114115

@@ -117,7 +118,7 @@ def dict_to_dataframe(mask_dict, column_order_list):
117118
def binarize_adata(adata, expression_threshold=0):
    """Binarize the expression matrix of an AnnData object.

    Parameters
    ----------
    adata : anndata.AnnData
        Object whose ``.X`` holds the expression values (sparse or dense).
    expression_threshold : numeric, default 0
        Entries strictly greater than this threshold become 1, all others 0.

    Returns
    -------
    numpy.ndarray
        Dense 0/1 integer matrix with the same shape as ``adata.X``.
    """
    if issparse(adata.X):
        # toarray() yields a plain ndarray; todense() would return the
        # deprecated np.matrix type, whose column slices stay 2-D and
        # surprise downstream consumers (e.g. mask/DataFrame builders).
        binary_expression = (adata.X.toarray() > expression_threshold).astype(int)
    else:
        binary_expression = (adata.X > expression_threshold).astype(int)
    return binary_expression
@@ -140,9 +141,11 @@ def add_neighbourhood_expression_mask(adata, grn_adata, strict=False):
140141
ne = get_neighborhood_expression(adata, required_neighbours=5)
141142
else:
142143
ne = binarize_adata(adata)
143-
mask = create_pairwise_binary_mask(ne, adata.var.index)
144+
mask = create_pairwise_binary_mask(ne, list(adata.var.index))
145+
144146
mask = dict_to_dataframe(mask, column_order_list = grn_adata.var.index)
145147
grn_adata.layers['mask'] = mask
148+
grn_adata.var['count_nonzero'] = np.sum(grn_adata.layers['mask'], axis =0)
146149
return grn_adata
147150

148151

src/netmap/utils/data_utils.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import os
33
import os.path as op
44

5+
import pandas as pd
6+
import scipy.sparse
7+
58
def attribution_to_anndata(attribution_list, var = None, obs = None)-> anndata.AnnData:
69

710
"""
@@ -29,4 +32,38 @@ def create_output_directory(result_params):
2932

3033

3134
def save_anndata(adobj, result_params):
    """Write an AnnData-like object to the configured output location.

    Parameters
    ----------
    adobj:
        Object exposing ``write(filename=...)`` (e.g. an AnnData instance).
    result_params : dict
        Must contain 'output_directory' and 'adata_filename'.
    """
    target_path = op.join(
        result_params['output_directory'],
        result_params['adata_filename'],
    )
    adobj.write(filename=target_path)
36+
37+
38+
39+
def merge_all_to_obs(target_adata, source_adata, replace=True):
    """Append every variable of ``source_adata`` as a column of ``target_adata.obs``.

    Intended for copying regulon activities stored in a separate AnnData
    next to the cell metadata of the main object, for easy plotting.

    Parameters
    ----------
    target_adata : anndata.AnnData
        Object whose ``.obs`` receives the new columns (modified in place).
    source_adata : anndata.AnnData
        Object whose ``.X`` columns (named by ``.var_names``) are appended.
        Must have the same number of observations as ``target_adata``.
    replace : bool, default True
        When columns of ``source_adata`` are already present in
        ``target_adata.obs``, drop all 'regulon' columns first and re-add
        them. When False, leave ``target_adata.obs`` untouched in that case.

    Returns
    -------
    anndata.AnnData
        ``target_adata`` with the merged ``.obs``.

    Raises
    ------
    ValueError
        If the observation counts of the two objects differ.
    """
    if target_adata.n_obs != source_adata.n_obs:
        raise ValueError("Cell counts do not match between objects.")

    # Densify if needed so the values can live in a plain DataFrame.
    if scipy.sparse.issparse(source_adata.X):
        source_data = source_adata.X.toarray()
    else:
        source_data = source_adata.X

    source_df = pd.DataFrame(
        source_data,
        index=source_adata.obs_names,
        columns=source_adata.var_names,
    )

    overlap = set(target_adata.obs.columns).intersection(source_df.columns)
    if overlap:
        if not replace:
            print('Regulon columns were present and not replaced.')
            return target_adata
        # Drop previously merged regulon columns before re-adding them.
        # NOTE(review): dropping selects by the 'regulon' substring, not by
        # the actual overlap — assumes merged columns are named '*regulon*'.
        spike_cols = [col for col in target_adata.obs.columns if 'regulon' in col]
        target_adata.obs = target_adata.obs.drop(columns=spike_cols)

    # Bug fix: the concat used to run only inside the overlap branch, so a
    # first merge (no pre-existing regulon columns) silently added nothing.
    target_adata.obs = pd.concat([target_adata.obs, source_df], axis=1)
    return target_adata

0 commit comments

Comments
 (0)