@@ -169,8 +169,8 @@ def attribution_one_target(
169169 target_gene ,
170170 lrp_model ,
171171 input_data ,
172- background ,
173172 xai_type = 'lrp-like' ,
173+ background_type = 'zeros' ,
174174 randomize_background = False ) -> list :
175175
176176 """
@@ -194,13 +194,22 @@ def attribution_one_target(
194194 List of np.ndarrays containing all attribution matrices for
195195 the specified target.
196196 """
197+
198+ if background_type == 'randomize' :
199+ background = shuffle_each_column_independently (input_data )
200+ elif background_type == 'zeros' :
201+ background = torch .zeros ((1 , input_data .shape [1 ]))
202+ background = background .cuda ()
203+ elif background_type == 'data' :
204+ background = input_data
205+ else :
206+ background = torch .zeros ((1 , input_data .shape [1 ]))
207+ background = background .cuda ()
208+
197209
198210 attributions_list = []
199211 for m in range (len (lrp_model )):
200212 # Randomize backgorund for each round
201- if randomize_background :
202- background = shuffle_each_column_independently (background )
203-
204213 model = lrp_model [m ]
205214 #for _ in range(num_iterations):
206215 if xai_type == 'lrp-like' :
@@ -213,7 +222,7 @@ def attribution_one_target(
213222 return attributions_list
214223
215224
216- def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GradientShap' ):
225+ def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GradientShap' , background_type == 'zeros' ):
217226
218227 """
219228 The main inferrence function to compute the entire GRN. Computes all
@@ -258,8 +267,8 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
258267 g ,
259268 tms ,
260269 data_train_full_tensor ,
261- data_train_full_tensor ,
262270 xai_type = xai_type ,
271+ background_type = background_type ,
263272 randomize_background = True )
264273
265274
0 commit comments