@@ -114,7 +114,9 @@ def aggregate_attributions(attributions, strategy = 'mean'):
114114 return np .mean (attributions , axis = 0 )
115115
116116
117- def get_explainer (model , explainer_type , raw = False ):
117+ def _get_explainer (model , explainer_type , raw = False ):
118+
119+
118120 if explainer_type in ['GuidedBackprop' , 'Deconvolution' ]:
119121 explainer_mode = 'lrp-like'
120122 else :
@@ -213,14 +215,39 @@ def attribution_one_target(
213215
214216def inferrence (models , data_train_full_tensor , gene_names , xai_method ):
215217
218+ """
219+ The main inferrence function to compute the entire GRN. Computes all
220+ attributions for all targets, aggregates them and creates an anndata.AnnData
221+ object with the edge names in the var slot.
222+
223+ Parameters
224+ ----------
225+    models : list[torch.nn.Module]
226+ List of trained autoencoder models
227+
228+    data_train_full_tensor: torch.Tensor
229+ input data tensor
230+
231+    gene_names: np.ndarray
232+        Gene names indicating the order of the genes in the torch tensor
233+
234+ xai_method: str
235+ Method to be used [GradientShap, Deconvolution, GuidedBackprop]
236+
237+ Returns
238+ -------
239+ grn_adata : anndata.AnnData
240+ A complete, aggregated GRN object
241+ """
242+
216243 tms = []
217244 name_list = []
218245 target_names = []
219246
220247
221248 for trained_model in models :
222249 trained_model .forward_mu_only = True
223- explainer , xai_type = get_explainer (trained_model , xai_method )
250+ explainer , xai_type = _get_explainer (trained_model , xai_method )
224251 tms .append (explainer )
225252
226253 attributions = []
@@ -265,6 +292,27 @@ def attribution_one_model(
265292 xai_type = 'lrp-like' ,
266293 background_type = 'randomize' ):
267294
295+
296+ """
297+ Compute attribution for one model.
298+
299+ Parameters
300+ ----------
301+
302+ lrp_model: list[LRP]
303+ List of LRP objects
304+ input_data: torch.tensor
305+ Tensor with the data the GRN should be computed for.
306+ xai_type: str
307+ Type of xai_model [Deconvolution, GradientShap, GuidedBackprop]
308+
309+ Returns
310+ -------
311+ attribution_list : np.ndarray
312+ Array containing the complete attributions for one model
313+
314+ """
315+
268316 attributions_list = []
269317
270318     # Randomize background for each round
@@ -298,9 +346,9 @@ def attribution_one_model(
298346
299347def inferrence_model_wise (models , data_train_full_tensor , gene_names , xai_method , n_models = [10 , 25 , 50 ], background_type = 'zeros' ):
300348
349+
350+
301351 tms = []
302- name_list = []
303- target_names = []
304352
305353 cou = [[f'{ tup [0 ]} _{ tup [1 ]} ' , tup [0 ], tup [1 ]] for tup in itertools .product (gene_names , gene_names )]
306354 cou = pd .DataFrame (cou )
@@ -310,7 +358,7 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
310358
311359 for trained_model in models :
312360 trained_model .forward_mu_only = True
313- explainer , xai_type = get_explainer (trained_model , xai_method )
361+ explainer , xai_type = _get_explainer (trained_model , xai_method )
314362 tms .append (explainer )
315363
316364
0 commit comments