-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnmf_workflow.py
More file actions
186 lines (155 loc) · 8.75 KB
/
nmf_workflow.py
File metadata and controls
186 lines (155 loc) · 8.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# nmf_workflow.py
import os
import json
import shutil
import numpy as np
import pandas as pd
from sklearn.decomposition import NMF
from scipy.sparse import issparse
from joblib import Parallel, delayed
from tqdm.auto import tqdm
from nmf_evaluation import calculate_nmf_evaluation_metrics
from nmf_plotting import plot_k_selection_results, plot_f1_gradient
def _run_nmf_for_single_k(k_val, X_input_subset, nmf_random_state, nmf_max_iter,
                          output_weights_dir_for_k, original_matrix_for_eval,
                          thresholds_eval, tmp_results_dir):
    """Fit a single-k NMF, persist its factors, and evaluate the reconstruction.

    The factor matrices are saved under *output_weights_dir_for_k*, and the
    metrics dict is both written to *tmp_results_dir* as JSON and returned.
    Returns None (after printing the error) if fitting or evaluation fails.
    """
    try:
        # Only n_components / random_state / max_iter are overridden; every
        # other NMF option stays at its scikit-learn default.
        nmf_model = NMF(n_components=k_val,
                        random_state=nmf_random_state,
                        max_iter=nmf_max_iter)
        # sklearn convention: fit_transform yields the per-sample factor,
        # components_ holds the per-feature factor.
        sample_factor = nmf_model.fit_transform(X_input_subset)
        feature_factor = nmf_model.components_
        os.makedirs(output_weights_dir_for_k, exist_ok=True)
        np.save(os.path.join(output_weights_dir_for_k, f"W_k{k_val}.npy"), feature_factor.T)
        np.save(os.path.join(output_weights_dir_for_k, f"H_k{k_val}.npy"), sample_factor)
        reconstruction = sample_factor @ feature_factor
        # The evaluation routine accepts the (possibly sparse) original matrix as-is.
        max_f1, auprc = calculate_nmf_evaluation_metrics(
            original_matrix_for_eval, reconstruction, thresholds_eval)
        # Cast numpy scalars to plain Python floats so json.dump can serialize them.
        result_dict = {
            'k': k_val,
            'max_mean_f1': float(max_f1),
            'auprc': float(auprc),
            'reconstruction_error': float(nmf_model.reconstruction_err_),
            'nmf_params': {
                'n_components': k_val,
                'random_state': nmf_random_state,
                'max_iter': nmf_max_iter,
                'solver': nmf_model.solver,
                'beta_loss': nmf_model.beta_loss,
                'init': nmf_model.init if nmf_model.init is not None else 'auto (sklearn default)',
                'tol': nmf_model.tol,
            },
        }
        json_path = os.path.join(tmp_results_dir, f"result_k_{k_val}.json")
        with open(json_path, 'w') as f:
            json.dump(result_dict, f, indent=2)
        return result_dict
    except Exception as e:
        print(f" ERROR in NMF for k={k_val}: {e}")
        return None
def _execute_nmf_loop(X_nmf_input, k_range, nmf_random_state, nmf_max_iter_val, thresholds,
                      base_output_dir, group_prefix, n_jobs):
    """Shared logic for running the NMF loop, plotting, and returning results.

    Dispatches one NMF fit per k in *k_range* (in parallel via joblib), writes
    a CSV summary plus a parameter log under *base_output_dir*, renders the
    k-selection plots, and returns the summary DataFrame (or None when every
    fit failed).
    """
    tmp_results_dir = os.path.join(base_output_dir, "tmp_results")
    os.makedirs(tmp_results_dir, exist_ok=True)
    # One delayed task per candidate k; the (possibly sparse) input matrix is
    # passed both as the NMF input and as the reference for evaluation.
    tasks = [
        delayed(_run_nmf_for_single_k)(
            k_candidate, X_nmf_input, nmf_random_state, nmf_max_iter_val,
            os.path.join(base_output_dir, f"{k_candidate}NMF", "weights"),
            X_nmf_input, thresholds, tmp_results_dir,
        )
        for k_candidate in k_range
    ]
    effective_n_jobs = os.cpu_count() if n_jobs == -1 else n_jobs
    print(f" INFO: Running NMF for k in {list(k_range)} using up to {effective_n_jobs} parallel jobs.")
    print(f" INFO: NMF Parameters: random_state={nmf_random_state}, max_iter={nmf_max_iter_val}, other params: scikit-learn defaults.")
    # Wrap the task iterable in tqdm so submission progress is visible.
    raw_outputs = Parallel(n_jobs=n_jobs)(
        task for task in tqdm(tasks, desc=f" Processing k values for {group_prefix}")
    )
    all_results = [entry for entry in raw_outputs if entry is not None]
    if not all_results:
        print(f" WARNING: No NMF results were successfully obtained for {group_prefix}. Check for errors.")
        if os.path.exists(tmp_results_dir):
            shutil.rmtree(tmp_results_dir)
        return None
    results_df = pd.DataFrame(all_results).sort_values(by='k').reset_index(drop=True)
    csv_path = os.path.join(base_output_dir, f"{group_prefix}_nmf_evaluation_summary.csv")
    results_df.to_csv(csv_path, index=False)
    print(f" INFO: Evaluation summary for {group_prefix} saved to: {csv_path}")
    run_params_log = {
        'group': group_prefix,
        'k_range': list(k_range),
        'nmf_random_state': nmf_random_state,
        'nmf_max_iter': nmf_max_iter_val,
        'thresholds_count': len(thresholds),
        'input_matrix_shape': X_nmf_input.shape,
        'sklearn_nmf_defaults_note': "NMF uses scikit-learn defaults for init, solver, beta_loss, tol unless overridden.",
    }
    params_path = os.path.join(base_output_dir, f"{group_prefix}_nmf_run_parameters.json")
    with open(params_path, 'w') as f:
        json.dump(run_params_log, f, indent=2)
    print(f" INFO: Overall NMF run parameters logged for {group_prefix}.")
    print(f" INFO: Generating plots for {group_prefix}...")
    plot_k_selection_results(results_df, base_output_dir, group_prefix)
    plot_f1_gradient(results_df, base_output_dir, group_prefix)
    if os.path.exists(tmp_results_dir):
        shutil.rmtree(tmp_results_dir)
    print(f"=== Finished NMF Pipeline for: {group_prefix} ===")
    return results_df
def run_nmf_pipeline_for_group(
    group_info, all_samples_list, sample_to_cancer_type_map,
    bool_map_overall, k_range, nmf_random_state, nmf_max_iter_val, thresholds,
    global_output_dir, n_jobs
):
    """Subsets data for a group and runs the NMF pipeline.

    Parameters:
        group_info: dict with 'group_name' and 'cancer_codes' keys.
        all_samples_list: ordered sample IDs (columns of bool_map_overall).
        sample_to_cancer_type_map: maps sample ID -> cancer-type string.
        bool_map_overall: features x samples matrix; the group's columns are
            selected, transposed, and cast to float32 for NMF.
        k_range, nmf_random_state, nmf_max_iter_val, thresholds, n_jobs:
            forwarded to _execute_nmf_loop.
        global_output_dir: parent directory for the group's output folder.

    Returns:
        The evaluation DataFrame from _execute_nmf_loop, or None when the
        group has no samples (or fewer samples than min(k_range)).
    """
    group_name = group_info['group_name']
    # 'scATAC' keeps its canonical capitalization; other groups use a
    # capitalized 3-letter prefix of their name.
    group_prefix = group_name[:3].capitalize() if group_name.lower() != 'scatac' else 'scATAC'
    min_k, max_k = min(k_range), max(k_range)
    print(f"\n=== Preparing NMF Pipeline for Group: {group_name} ({group_prefix}) ===")
    group_base_output_dir = os.path.join(global_output_dir, f"{group_prefix}_NMF_K_opt_{min_k}_{max_k}")
    group_cancer_codes = set(group_info['cancer_codes'])
    if group_name.lower() == 'scatac':
        print(" INFO: Filtering for scATAC samples based on sample ID prefix.")
        group_sample_indices = [
            idx for idx, sample_id in enumerate(all_samples_list)
            if sample_to_cancer_type_map.get(sample_id, "").startswith('scATAC_')
        ]
        selected_samples = [all_samples_list[idx] for idx in group_sample_indices]
    else:
        print(f" INFO: Filtering for TCGA samples with codes: {group_cancer_codes}")
        # A sample belongs to the group when the first 4 chars of its cancer
        # type match one of the group's codes.
        group_sample_indices = [
            idx for idx, sample_id in enumerate(all_samples_list)
            if (c_type := sample_to_cancer_type_map.get(sample_id)) and c_type[:4] in group_cancer_codes
        ]
        selected_samples = [all_samples_list[idx] for idx in group_sample_indices]
    # Export selected sample IDs to preprocessed_data
    export_path = os.path.join("preprocessed_data", f"selected_samples_{group_name}.json")
    # FIX: ensure the export directory exists; previously this open() raised
    # FileNotFoundError when 'preprocessed_data' was absent.
    os.makedirs(os.path.dirname(export_path), exist_ok=True)
    with open(export_path, "w") as f:
        json.dump(selected_samples, f, indent=2)
    print(f" INFO: Exported selected sample IDs to {export_path}")
    if not group_sample_indices:
        print(f" WARNING: No samples found for {group_name}. Skipping.")
        return None
    print(f" INFO: Found {len(group_sample_indices)} samples for group {group_name}.")
    # Transpose to (samples x features); float32 halves memory vs float64.
    X_nmf_input_group = bool_map_overall[:, group_sample_indices].T.astype(np.float32)
    if X_nmf_input_group.shape[0] < min(k_range):
        print(f" WARNING: Not enough samples ({X_nmf_input_group.shape[0]}) for min_k={min(k_range)}. Skipping.")
        return None
    print(f" INFO: NMF input matrix for {group_name} (samples x features): {X_nmf_input_group.shape}")
    return _execute_nmf_loop(
        X_nmf_input_group, k_range, nmf_random_state, nmf_max_iter_val, thresholds,
        group_base_output_dir, group_prefix, n_jobs
    )
def run_nmf_pipeline_for_all_samples(bool_map_overall, k_range, nmf_random_state, nmf_max_iter_val, thresholds,
                                     global_output_dir, n_jobs):
    """Runs the NMF pipeline on the full dataset without subsetting."""
    group_prefix = "AllSamples"
    k_lo, k_hi = min(k_range), max(k_range)
    print(f"\n=== Preparing NMF Pipeline for ALL SAMPLES ({group_prefix}) ===")
    output_dir = os.path.join(global_output_dir, f"{group_prefix}_NMF_K_opt_{k_lo}_{k_hi}")
    # Transpose to (samples x features); float32 halves memory use during NMF.
    full_input = bool_map_overall.T.astype(np.float32)
    print(f" INFO: NMF input matrix for All Samples (samples x features): {full_input.shape}")
    return _execute_nmf_loop(
        full_input, k_range, nmf_random_state, nmf_max_iter_val, thresholds,
        output_dir, group_prefix, n_jobs,
    )