@@ -221,7 +221,7 @@ def attribution_one_target(
221221 return attributions_list
222222
223223
224- def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GradientShap' , background_type = 'zeros' , raw = False ):
224+ def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GradientShap' , background_type = 'zeros' , backing_file = 'grn_adata.h5' , return_in_memory = False ):
225225
226226 """
227227 The main inferrence function to compute the entire GRN. Computes all
@@ -255,40 +255,84 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
255255
256256 for trained_model in models :
257257 trained_model .forward_mu_only = True
258- explainer , xai_type = _get_explainer (trained_model , xai_method , raw = raw )
258+ explainer , xai_type = _get_explainer (trained_model , xai_method , raw = False )
259259 tms .append (explainer )
260260
261261 attributions = []
262262
263- for g in tqdm (range (data_train_full_tensor .shape [1 ])):
264- attributions_list = attribution_one_target (
265- g ,
266- tms ,
267- data_train_full_tensor ,
268- xai_type = xai_method ,
269- background_type = background_type )
263+ rows = data_train_full_tensor .shape [0 ]
264+ cols = data_train_full_tensor .shape [1 ]
265+ cols_grn = cols * cols
270266
271-
272- attributions_list = aggregate_attributions (attributions_list , strategy = 'mean' )
273- attributions .append (attributions_list )
267+ if backing_file is not None :
268+ with h5py .File (backing_file , 'w' ) as f :
269+
270+ dset = f .create_dataset (
271+ 'data' ,
272+ shape = (rows , cols_grn ),
273+ dtype = 'float32' ,
274+ chunks = (rows , cols )
275+ )
276+
277+ for g in tqdm (range (data_train_full_tensor .shape [1 ])):
278+ attributions_list = attribution_one_target (
279+ g ,
280+ tms ,
281+ data_train_full_tensor ,
282+ xai_type = xai_type ,
283+ background_type = background_type )
284+
285+
286+
287+ attributions_list = aggregate_attributions (attributions_list , strategy = 'mean' )
288+ dset [:, (g * cols ): ((g + 1 )* cols )] = attributions_list
274289
275- ## AGGREGATION: REPLACE LIST BY AGGREGATED DATA
276- for i in range (len (attributions )):
290+ else :
291+ for g in tqdm (range (data_train_full_tensor .shape [1 ])):
292+ attributions_list = attribution_one_target (
293+ g ,
294+ tms ,
295+ data_train_full_tensor ,
296+ xai_type = xai_type ,
297+ background_type = background_type )
298+
299+
300+
301+ attributions_list = aggregate_attributions (attributions_list , strategy = 'mean' )
302+ attributions .append (attributions_list )
303+
304+ attributions = np .hstack (attributions )
277305
306+ for i in range (cols ):
278307 ## Create name vector
279308 name_list = name_list + list (gene_names )
280309 target_names = target_names + [gene_names [i ]] * len (gene_names )
281310
282-
283-
284- attributions = np .hstack (attributions )
285311
286312 index_list = [f"{ s } _{ t } " for (s , t ) in zip (name_list , target_names )]
287313 cou = pd .DataFrame ({'index' : index_list , 'source' :name_list , 'target' :target_names })
288314 cou = cou .set_index ('index' )
289315
290- grn_adata = attribution_to_anndata (attributions , var = cou )
316+ if backing_file is not None :
317+ if return_in_memory :
318+ with h5py .File (backing_file , 'r+' ) as f :
319+ dset = f ['data' ]
320+ grn_adata = ad .AnnData (dset , uns = {'backing_file' : backing_file }, var = cou )
321+ grn_adata = grn_adata .to_memory ()
322+
323+ else :
324+ grn_adata = ad .AnnData (shape = (rows , cols_grn ), uns = {'backing_file' : backing_file }, var = cou )
325+ else :
326+ grn_adata = attribution_to_anndata (attributions , var = cou )
327+
328+ return grn_adata
329+
291330
def return_grn_adata_to_memory(grn_adata):
    """Load a backed GRN attribution matrix fully into memory.

    Re-opens the HDF5 backing file recorded in
    ``grn_adata.uns['backing_file']`` (written by ``inferrence``), attaches
    its ``'data'`` dataset as ``X``, and returns the in-memory AnnData.

    Parameters
    ----------
    grn_adata : AnnData
        AnnData whose ``uns['backing_file']`` points at the HDF5 file
        produced by ``inferrence`` (dataset key ``'data'``).

    Returns
    -------
    AnnData
        The same annotations with ``X`` materialised in memory.
    """
    # Fix: open read-only ('r') instead of read-write ('r+') — this function
    # only reads the dataset, so it should not require write permission and
    # must not risk modifying the backing file.
    with h5py.File(grn_adata.uns['backing_file'], 'r') as f:
        dset = f['data']
        grn_adata.X = dset
        # NOTE(review): to_memory() is expected to copy the HDF5-backed X
        # into RAM before the file handle closes on exiting the `with` —
        # confirm against the installed anndata version.
        grn_adata = grn_adata.to_memory()
    return grn_adata
293337
294338
0 commit comments