@@ -346,7 +346,36 @@ def attribution_one_model(
346346
347347def inferrence_model_wise (models , data_train_full_tensor , gene_names , xai_method , n_models = [10 , 25 , 50 ], background_type = 'zeros' ):
348348
349+ """
350+ The main inference function to compute the entire GRN model wise. Computes all
351+ attributions for all targets, aggregates them on the fly and creates an anndata.AnnData
352+ object with the edge names in the var slot.
353+
354+ Parameters
355+ ----------
356+ models : list[torch.Model]
357+ List of trained autoencoder models
358+
359+ data_train_full_tensor: torch.tensor
360+ input data tensor
361+
362+ gene_names: np.array
363+ Gene names indicating the order of the genes in the torch tensor
364+
365+ xai_method: str
366+ Method to be used [GradientShap, Deconvolution, GuidedBackprop]
367+
368+ n_models: list [int]
369+ Returns aggregates of the attributions at these levels.
370+
371+ background_type: str
372+ Background to compute the LRP values against. One of ['zeros', 'randomize', 'data']
349373
374+ Returns
375+ -------
376+ grn_adata : anndata.AnnData
377+ A complete, aggregated GRN object
378+ """
350379
351380 tms = []
352381
@@ -361,13 +390,9 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
361390 explainer , xai_type = _get_explainer (trained_model , xai_method )
362391 tms .append (explainer )
363392
364-
365- thresholds = [0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 ]
366-
367393 attributions = {}
368394 attribution_collector = None
369395 keynames = []
370- top_egde_collector = {}
371396
372397
373398 for m in range (len (tms )):
@@ -380,19 +405,6 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
380405 background_type = background_type )
381406
382407
383- # grn_adata_eph = attribution_to_anndata(current_attribution, var=cou)
384- # b = np.argsort(grn_adata_eph.X, axis=1)
385- # grn_adata_eph.layers['sorted'] = b
386- # grn_adata_eph = edge_selection.add_top_edge_annotation_global(grn_adata=grn_adata_eph, top_edges = thresholds, key_name=f'agg_{m}')
387- # df_subset = grn_adata_eph.var.iloc[:, 2:]
388- # integral_results = df_subset.apply(
389- # lambda row: np.sum(integrate.cumulative_trapezoid(row, thresholds )),
390- # axis=1,
391- # )
392- # integral_results = integral_results/1000
393- # top_egde_collector[f'agg_{m}'] = integral_results
394-
395-
396408 if attribution_collector is not None :
397409 # add current attribution to the collector
398410 attribution_collector = aggregate_attributions ([attribution_collector , current_attribution ], strategy = 'sum' )
@@ -402,11 +414,13 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
402414 attribution_collector = current_attribution
403415
404416
405-
406- if (m + 1 ) in n_models :
407- # dont reset, just save the correct matrix
408- attributions [f'aggregated_{ (m + 1 )} ' ] = attribution_collector / (m + 1 )
409- keynames .append (f'aggregated_{ (m + 1 )} ' )
417+ try :
418+ if (m + 1 ) in n_models :
419+ # dont reset, just save the correct matrix
420+ attributions [f'aggregated_{ (m + 1 )} ' ] = attribution_collector / (m + 1 )
421+ keynames .append (f'aggregated_{ (m + 1 )} ' )
422+ except :
423+ pass
410424
411425
412426 # top_egde_collector = pd.DataFrame(top_egde_collector)
@@ -418,7 +432,8 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
418432
419433 grn_adata = attribution_to_anndata (attributions [keynames [0 ]], var = cou )
420434
421- for k in keynames [1 :len (keynames )]:
422- # add remaining versions as masks
423- grn_adata .layers [k ] = attributions [k ]
435+ if len (keynames )> 0 :
436+ for k in keynames [1 :len (keynames )]:
437+ # add remaining versions as masks
438+ grn_adata .layers [k ] = attributions [k ]
424439 return grn_adata
0 commit comments