@@ -213,7 +213,7 @@ def attribution_one_target(
213213 return attributions_list
214214
215215
216- def inferrence (models , data_train_full_tensor , gene_names , xai_method ):
216+ def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GuidedBackprop' ):
217217
218218 """
219219 The main inferrence function to compute the entire GRN. Computes all
@@ -344,7 +344,7 @@ def attribution_one_model(
344344 return attributions
345345
346346
347- def inferrence_model_wise (models , data_train_full_tensor , gene_names , xai_method , n_models = [10 , 25 , 50 ], background_type = 'zeros' ):
347+ def inferrence_model_wise (models , data_train_full_tensor , gene_names , xai_method = 'GuidedBackprop' , n_models = [10 , 25 , 50 ], background_type = 'zeros' ):
348348
349349 """
350350 The main inferrence function to compute the entire GRN model wise. Computes all
@@ -430,10 +430,13 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
430430
431431 # cou = cou.merge(top_egde_collector, left_index = True, right_on='edge_key')
432432
433- grn_adata = attribution_to_anndata (attributions [keynames [0 ]], var = cou )
434433
435434 if len (keynames )> 0 :
435+ grn_adata = attribution_to_anndata (attributions [keynames [0 ]], var = cou )
436+
436437 for k in keynames [1 :len (keynames )]:
437438 # add remaining versions as masks
438439 grn_adata .layers [k ] = attributions [k ]
439- return grn_adata
440+ return grn_adata
441+ else :
442+ return None
0 commit comments