Skip to content

Commit 61087cd

Browse files
efficient version
1 parent 50072f5 commit 61087cd

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

src/netmap/grn/inferrence.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def attribution_one_target(
169169
target_gene,
170170
lrp_model,
171171
input_data,
172-
background,
173172
xai_type='lrp-like',
173+
background_type = 'zeros',
174174
randomize_background = False) -> list:
175175

176176
"""
@@ -194,13 +194,22 @@ def attribution_one_target(
194194
List of np.ndarrays containing all attribution matrices for
195195
the specified target.
196196
"""
197+
198+
if background_type == 'randomize':
199+
background = shuffle_each_column_independently(input_data)
200+
elif background_type == 'zeros':
201+
background = torch.zeros((1, input_data.shape[1]))
202+
background = background.cuda()
203+
elif background_type == 'data':
204+
background = input_data
205+
else:
206+
background = torch.zeros((1, input_data.shape[1]))
207+
background = background.cuda()
208+
197209

198210
attributions_list = []
199211
for m in range(len(lrp_model)):
200212
# Randomize backgorund for each round
201-
if randomize_background:
202-
background = shuffle_each_column_independently(background)
203-
204213
model = lrp_model[m]
205214
#for _ in range(num_iterations):
206215
if xai_type == 'lrp-like':
@@ -213,7 +222,7 @@ def attribution_one_target(
213222
return attributions_list
214223

215224

216-
def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientShap'):
225+
def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientShap', background_type == 'zeros'):
217226

218227
"""
219228
The main inferrence function to compute the entire GRN. Computes all
@@ -258,8 +267,8 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
258267
g,
259268
tms,
260269
data_train_full_tensor,
261-
data_train_full_tensor,
262270
xai_type=xai_type,
271+
background_type= background_type,
263272
randomize_background = True)
264273

265274

0 commit comments

Comments
 (0)