Skip to content

Commit 498c0fc

Browse files
update API docu
1 parent cdfb003 commit 498c0fc

1 file changed

Lines changed: 53 additions & 5 deletions

File tree

src/netmap/grn/inferrence.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def aggregate_attributions(attributions, strategy = 'mean'):
114114
return np.mean(attributions, axis = 0)
115115

116116

117-
def get_explainer(model, explainer_type, raw=False):
117+
def _get_explainer(model, explainer_type, raw=False):
118+
119+
118120
if explainer_type in ['GuidedBackprop', 'Deconvolution']:
119121
explainer_mode = 'lrp-like'
120122
else:
@@ -213,14 +215,39 @@ def attribution_one_target(
213215

214216
def inferrence(models, data_train_full_tensor, gene_names, xai_method):
215217

218+
"""
219+
The main inferrence function to compute the entire GRN. Computes all
220+
attributions for all targets, aggregates them and creates an anndata.AnnData
221+
object with the edge names in the var slot.
222+
223+
Parameters
224+
----------
225+
models : list[torch.Model]
226+
List of trained autoencoder models
227+
228+
data_train_full_tensor: torch.tensor
229+
input data tensor
230+
231+
gene_names: np.array
232+
Gene names indicating the order of the genes in the torch tensort
233+
234+
xai_method: str
235+
Method to be used [GradientShap, Deconvolution, GuidedBackprop]
236+
237+
Returns
238+
-------
239+
grn_adata : anndata.AnnData
240+
A complete, aggregated GRN object
241+
"""
242+
216243
tms = []
217244
name_list = []
218245
target_names = []
219246

220247

221248
for trained_model in models:
222249
trained_model.forward_mu_only = True
223-
explainer, xai_type = get_explainer(trained_model, xai_method)
250+
explainer, xai_type = _get_explainer(trained_model, xai_method)
224251
tms.append(explainer)
225252

226253
attributions = []
@@ -265,6 +292,27 @@ def attribution_one_model(
265292
xai_type='lrp-like',
266293
background_type = 'randomize'):
267294

295+
296+
"""
297+
Compute attribution for one model.
298+
299+
Parameters
300+
----------
301+
302+
lrp_model: list[LRP]
303+
List of LRP objects
304+
input_data: torch.tensor
305+
Tensor with the data the GRN should be computed for.
306+
xai_type: str
307+
Type of xai_model [Deconvolution, GradientShap, GuidedBackprop]
308+
309+
Returns
310+
-------
311+
attribution_list : np.ndarray
312+
Array containing the complete attributions for one model
313+
314+
"""
315+
268316
attributions_list = []
269317

270318
# Randomize backgorund for each round
@@ -298,9 +346,9 @@ def attribution_one_model(
298346

299347
def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method, n_models = [10, 25, 50], background_type = 'zeros'):
300348

349+
350+
301351
tms = []
302-
name_list = []
303-
target_names = []
304352

305353
cou = [[f'{tup[0]}_{tup[1]}', tup[0], tup[1]] for tup in itertools.product(gene_names, gene_names)]
306354
cou = pd.DataFrame(cou)
@@ -310,7 +358,7 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
310358

311359
for trained_model in models:
312360
trained_model.forward_mu_only = True
313-
explainer, xai_type = get_explainer(trained_model, xai_method)
361+
explainer, xai_type = _get_explainer(trained_model, xai_method)
314362
tms.append(explainer)
315363

316364

0 commit comments

Comments
 (0)