@@ -227,6 +227,12 @@ def attribution_one_target(
227227 attributions_list .append (attribution .detach ().cpu ().numpy ())
228228 return attributions_list
229229
230+ import pyarrow as pa
231+ import pyarrow .dataset as ds
232+ import pyarrow .parquet as pq
233+ import numpy as np
234+ import os
235+ from tqdm import tqdm
230236
231237def inferrence (models , data_train_full_tensor , gene_names , xai_method = 'GradientShap' , background_type = 'zeros' , backing_file = 'grn_adata.h5' , return_in_memory = False ):
232238
@@ -278,45 +284,49 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
278284
279285 if backing_file is not None :
280286
281- dummy_data = np . zeros (( rows , cols ), dtype = "float32" )
282- column_names = [ f"col_ { i } " for i in range ( cols )]
283- dummy_table = pa . table ({ name : dummy_data [:, i ] for i , name in enumerate ( column_names )} )
287+ # Configuration
288+ output_dir = op . dirname ( backing_file )
289+ os . makedirs ( output_dir , exist_ok = True )
284290
285- # Arrow IPC writer with zstd compression
286- writer = ipc .new_file (
287- backing_file ,
288- dummy_table .schema ,
289- options = ipc .IpcWriteOptions (compression = "zstd" )
290- )
291+ name_list = list (gene_names )
292+ name = 'attr'
293+
294+ for i in range (cols ):
295+ ## Create name vector
296+ name_list = name_list + list (gene_names )
297+ target_names = target_names + [gene_names [i ]] * len (gene_names )
298+ column_names = [f'{ s } _{ t } ' for s ,t in zip (name_list , target_names )]
291299
292- for g in tqdm ( range ( data_train_full_tensor . shape [ 1 ])):
300+ schema = pa . schema ([( name , pa . float32 ()) for name in column_names ])
293301
302+ # Loop through your column-wise groups
303+ for g in tqdm (range (data_train_full_tensor .shape [1 ])):
304+ # Generate your column-wise chunk (shape: [rows, cols])
294305 attributions_list = attribution_one_target (
295- g ,
296- tms ,
297- data_train_full_tensor ,
298- xai_type = xai_type ,
299- background_type = background_type
306+ g , tms , data_train_full_tensor ,
307+ xai_type = xai_type , background_type = background_type
300308 )
301-
309+
302310 attributions_list = aggregate_attributions (attributions_list , strategy = 'mean' )
303-
311+
304312 collect_sums .append (np .sum (attributions_list , axis = 0 ))
305313 collect_means .append (np .mean (attributions_list , axis = 0 ))
306314
307- source_list = list (gene_names )
308- target_names = [gene_names [g ]] * len (gene_names )
309- edge_names = [f'{ s } _{ t } ' for s ,t in zip (source_list , target_names )]
310-
315+ # 2. Convert the column-chunk to a PyArrow Table
316+ # Map the numpy chunk to the specific column names for this group 'g'
317+ current_col_names = column_names [g * cols : (g + 1 )* cols ]
311318
312- table = pa .table ({
313- edge_names [i ]: attributions_list [:, i ]
314- for i in range (attributions_list .shape [1 ])
315- })
319+ # We create a table where each slice of the numpy array is a column
320+ chunk_table = pa .table (
321+ [attributions_list [:, i ] for i in range (attributions_list .shape [1 ])],
322+ names = current_col_names
323+ )
316324
317- writer .write_table (table )
325+ # 3. Write this specific column-group to a Parquet file
326+ # In a dataset, these will be "sharded" columns
327+ file_path = os .path .join (output_dir , f"{ gene_names [g ]} .parquet" )
328+ pq .write_table (chunk_table , file_path )
318329
319- writer .close ()
320330
321331
322332 else :
0 commit comments