Skip to content

Commit c258d3b

Browse files
change to parquet
1 parent 2b4b025 commit c258d3b

1 file changed

Lines changed: 37 additions & 27 deletions

File tree

src/netmap/grn/inferrence.py

Lines changed: 37 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -227,6 +227,12 @@ def attribution_one_target(
227227
attributions_list.append(attribution.detach().cpu().numpy())
228228
return attributions_list
229229

230+
import pyarrow as pa
231+
import pyarrow.dataset as ds
232+
import pyarrow.parquet as pq
233+
import numpy as np
234+
import os
235+
from tqdm import tqdm
230236

231237
def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientShap', background_type = 'zeros', backing_file='grn_adata.h5', return_in_memory=False):
232238

@@ -278,45 +284,49 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
278284

279285
if backing_file is not None:
280286

281-
dummy_data = np.zeros((rows, cols), dtype="float32")
282-
column_names = [f"col_{i}" for i in range(cols)]
283-
dummy_table = pa.table({name: dummy_data[:, i] for i, name in enumerate(column_names)})
287+
# Configuration
288+
output_dir = op.dirname(backing_file)
289+
os.makedirs(output_dir, exist_ok=True)
284290

285-
# Arrow IPC writer with zstd compression
286-
writer = ipc.new_file(
287-
backing_file,
288-
dummy_table.schema,
289-
options=ipc.IpcWriteOptions(compression="zstd")
290-
)
291+
name_list = list(gene_names)
292+
name = 'attr'
293+
294+
for i in range(cols):
295+
## Create name vector
296+
name_list = name_list + list(gene_names)
297+
target_names = target_names+[gene_names[i]] *len(gene_names)
298+
column_names = [f'{s}_{t}' for s,t in zip(name_list, target_names)]
291299

292-
for g in tqdm(range(data_train_full_tensor.shape[1])):
300+
schema = pa.schema([(name, pa.float32()) for name in column_names])
293301

302+
# Loop through your column-wise groups
303+
for g in tqdm(range(data_train_full_tensor.shape[1])):
304+
# Generate your column-wise chunk (shape: [rows, cols])
294305
attributions_list = attribution_one_target(
295-
g,
296-
tms,
297-
data_train_full_tensor,
298-
xai_type=xai_type,
299-
background_type=background_type
306+
g, tms, data_train_full_tensor,
307+
xai_type=xai_type, background_type=background_type
300308
)
301-
309+
302310
attributions_list = aggregate_attributions(attributions_list, strategy='mean')
303-
311+
304312
collect_sums.append(np.sum(attributions_list, axis=0))
305313
collect_means.append(np.mean(attributions_list, axis=0))
306314

307-
source_list = list(gene_names)
308-
target_names = [gene_names[g]] *len(gene_names)
309-
edge_names = [f'{s}_{t}' for s,t in zip(source_list, target_names)]
310-
315+
# 2. Convert the column-chunk to a PyArrow Table
316+
# Map the numpy chunk to the specific column names for this group 'g'
317+
current_col_names = column_names[g*cols : (g+1)*cols]
311318

312-
table = pa.table({
313-
edge_names[i]: attributions_list[:, i]
314-
for i in range(attributions_list.shape[1])
315-
})
319+
# We create a table where each slice of the numpy array is a column
320+
chunk_table = pa.table(
321+
[attributions_list[:, i] for i in range(attributions_list.shape[1])],
322+
names=current_col_names
323+
)
316324

317-
writer.write_table(table)
325+
# 3. Write this specific column-group to a Parquet file
326+
# In a dataset, these will be "sharded" columns
327+
file_path = os.path.join(output_dir, f"{gene_names[i]}.parquet")
328+
pq.write_table(chunk_table, file_path)
318329

319-
writer.close()
320330

321331

322332
else:

0 commit comments

Comments (0)