Skip to content

Commit ebe3c41

Browse files
add disk backing for inference and delete obsolete raw parameter from interface
1 parent c26f7e5 commit ebe3c41

1 file changed

Lines changed: 62 additions & 18 deletions

File tree

src/netmap/grn/inferrence.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def attribution_one_target(
221221
return attributions_list
222222

223223

224-
def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientShap', background_type = 'zeros', raw=False):
224+
def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientShap', background_type = 'zeros', backing_file='grn_adata.h5', return_in_memory=False):
225225

226226
"""
227227
The main inferrence function to compute the entire GRN. Computes all
@@ -255,40 +255,84 @@ def inferrence(models, data_train_full_tensor, gene_names, xai_method='GradientS
255255

256256
for trained_model in models:
257257
trained_model.forward_mu_only = True
258-
explainer, xai_type = _get_explainer(trained_model, xai_method, raw=raw)
258+
explainer, xai_type = _get_explainer(trained_model, xai_method, raw=False)
259259
tms.append(explainer)
260260

261261
attributions = []
262262

263-
for g in tqdm(range(data_train_full_tensor.shape[1])):
264-
attributions_list = attribution_one_target(
265-
g,
266-
tms,
267-
data_train_full_tensor,
268-
xai_type=xai_method,
269-
background_type= background_type)
263+
rows = data_train_full_tensor.shape[0]
264+
cols = data_train_full_tensor.shape[1]
265+
cols_grn = cols*cols
270266

271-
272-
attributions_list = aggregate_attributions(attributions_list, strategy='mean')
273-
attributions.append(attributions_list)
267+
if backing_file is not None:
268+
with h5py.File(backing_file, 'w') as f:
269+
270+
dset = f.create_dataset(
271+
'data',
272+
shape=(rows, cols_grn),
273+
dtype='float32',
274+
chunks=(rows, cols)
275+
)
276+
277+
for g in tqdm(range(data_train_full_tensor.shape[1])):
278+
attributions_list = attribution_one_target(
279+
g,
280+
tms,
281+
data_train_full_tensor,
282+
xai_type=xai_type,
283+
background_type= background_type)
284+
285+
286+
287+
attributions_list = aggregate_attributions(attributions_list, strategy='mean')
288+
dset[:, (g*cols): ((g+1)*cols)] = attributions_list
274289

275-
## AGGREGATION: REPLACE LIST BY AGGREGATED DATA
276-
for i in range(len(attributions)):
290+
else:
291+
for g in tqdm(range(data_train_full_tensor.shape[1])):
292+
attributions_list = attribution_one_target(
293+
g,
294+
tms,
295+
data_train_full_tensor,
296+
xai_type=xai_type,
297+
background_type= background_type)
298+
299+
300+
301+
attributions_list = aggregate_attributions(attributions_list, strategy='mean')
302+
attributions.append(attributions_list)
303+
304+
attributions = np.hstack(attributions)
277305

306+
for i in range(cols):
278307
## Create name vector
279308
name_list = name_list + list(gene_names)
280309
target_names = target_names+[gene_names[i]] *len(gene_names)
281310

282-
283-
284-
attributions = np.hstack(attributions)
285311

286312
index_list = [f"{s}_{t}" for (s, t) in zip(name_list, target_names)]
287313
cou = pd.DataFrame({'index': index_list, 'source':name_list, 'target':target_names})
288314
cou = cou.set_index('index')
289315

290-
grn_adata = attribution_to_anndata(attributions, var=cou)
316+
if backing_file is not None:
317+
if return_in_memory:
318+
with h5py.File(backing_file, 'r+') as f:
319+
dset = f['data']
320+
grn_adata = ad.AnnData(dset, uns = {'backing_file': backing_file}, var = cou)
321+
grn_adata = grn_adata.to_memory()
322+
323+
else:
324+
grn_adata = ad.AnnData(shape = (rows, cols_grn), uns = {'backing_file': backing_file}, var = cou)
325+
else:
326+
grn_adata = attribution_to_anndata(attributions, var=cou)
327+
328+
return grn_adata
329+
291330

331+
def return_grn_adata_to_memory(grn_adata):
332+
with h5py.File(grn_adata.uns['backing_file'], 'r+') as f:
333+
dset = f['data']
334+
grn_adata.X = dset
335+
grn_adata = grn_adata.to_memory()
292336
return grn_adata
293337

294338

0 commit comments

Comments
 (0)