forked from FraggeGaming/ThesisInferenceServer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTestServer.py
More file actions
362 lines (284 loc) · 12.1 KB
/
TestServer.py
File metadata and controls
362 lines (284 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
import os
import sys
import json
import csv
import subprocess
import nibabel as nib
from flask import Flask, request, send_file, jsonify
from dataclasses import dataclass, asdict
from subprocess import Popen
import threading
import platform
import signal
from typing import List, Dict
from datetime import datetime
import traceback
# Flask application object; every @app.route handler below registers on it.
app = Flask(__name__)
# Public description of a model, as sent to the UI.
# The field order defines the generated __init__ (read_existing_models builds it
# with AIModel(**item)) and the key order produced by asdict() in /getmodels.
@dataclass
class AIModel:
    id: int  # unique id; key of the dict returned by read_existing_models()
    title: str  # model title
    description: str  # free-text description
    inputModality: str  # filtered on by /getmodels and listed by /modalities
    outputModality: str  # presumably the modality the model generates — confirm
    region: str  # filtered on by /getmodels and listed by /regions
# Registry entry: an AIModel plus the server-side details needed to run it.
# Only `model` is serialized by /getmodels; the other two fields stay on the server.
@dataclass
class ModelWithPath:
    model: AIModel
    modelPath: str  # checkpoint folder, joined onto CHECKPOINTS_DIR by /process
    networkName: str  # passed to the inference script as --which_epoch
# Base directory: PyInstaller's extraction directory (sys._MEIPASS) when running
# as a bundled executable, otherwise the directory containing this file.
# NOTE(review): in PyInstaller one-file mode _MEIPASS is a temporary directory,
# so output/, uploads/ and Logs/ created under it may not survive a restart —
# confirm this is intended.
BASE_DIR = getattr(sys, '_MEIPASS', os.path.abspath(os.path.dirname(__file__)))
# Paths relative to the base directory
DATA_PATH = os.path.join(BASE_DIR, "codice_curriculum")  # Inference code; also passed as --dataroot
TEST_SCRIPT = os.path.join(DATA_PATH, "test_interface.py")  # Inference script launched by /process
CHECKPOINTS_DIR = os.path.join(BASE_DIR, "checkpoints")  # Model checkpoint ("switch") folders
OUTPUT_DIR = os.path.join(BASE_DIR, "output")  # Generated files, read by /download as {job_id}.nii.gz
# Uploaded input files, saved by /process as {job_id}.nii.gz.
# Note: these files are deleted when the user downloads the output or cancels the job.
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
LOG_DIR = os.path.join(BASE_DIR, "Logs")  # Default directory for write_error_log()
# Ensure necessary folders exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
# Inference subprocesses keyed by job_id, kept so a job can be cancelled via /cancel.
running_processes = {}
lock = threading.Lock()  # Lock taken by /cancel when reading/removing running_processes entries
# Latest progress dict for each job_id, parsed from the inference subprocess's stdout.
progress_state = {}
progress_lock = threading.Lock()  # Guards progress_state
def write_error_log(error_text: str, log_dir: str = LOG_DIR) -> str:
    """
    Write an error log to a uniquely named file inside the specified log directory.

    The file name embeds a microsecond timestamp, and the file is created in
    exclusive mode. Two errors logged in the same instant (for example from a
    request handler and a progress-reader thread) therefore never overwrite
    each other. On a name clash, a numeric suffix is appended.

    Args:
        error_text (str): The error message or traceback to log.
        log_dir (str): Path to the logs directory. Defaults to global LOG_DIR.
            Created if it does not exist.

    Returns:
        str: The path to the written log file.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    suffix = 0
    while True:
        if suffix == 0:
            log_filename = f"error_{timestamp}.txt"
        else:
            log_filename = f"error_{timestamp}_{suffix}.txt"
        log_path = os.path.join(log_dir, log_filename)
        try:
            # "x" raises FileExistsError instead of truncating an existing log.
            with open(log_path, "x", encoding="utf-8") as log_file:
                log_file.write(error_text)
        except FileExistsError:
            suffix += 1
            continue
        return log_path
def read_existing_models(path: str = "models.json") -> Dict[int, ModelWithPath]:
    """
    Load the model registry JSON file and index its entries by model id.

    Each entry in the file is an object holding the AIModel fields plus the
    server-only keys "modelPath" and "networkName". Those two keys are split
    off before the AIModel is built.

    Args:
        path: Registry file to read. A relative path resolves against the
            current working directory.

    Returns:
        Mapping of model id to its ModelWithPath wrapper. If ids repeat, the
        later entry wins.
    """
    def _wrap(entry: dict) -> ModelWithPath:
        # Remove the server-only keys so AIModel(**entry) only sees its own fields.
        checkpoint_folder = entry.pop("modelPath")
        network = entry.pop("networkName")
        return ModelWithPath(
            model=AIModel(**entry),
            modelPath=checkpoint_folder,
            networkName=network,
        )

    with open(path, "r") as handle:
        entries = json.load(handle)

    wrappers = (_wrap(entry) for entry in entries)
    return {wrapper.model.id: wrapper for wrapper in wrappers}
def get_model_path_by_id(model_id: int, path: str = "models.json") -> str:
    """
    Return the checkpoint folder ("modelPath") registered for ``model_id``.

    Args:
        model_id: Id of the model to look up.
        path: Model registry file to read. Defaults to "models.json", the same
            optional parameter get_network_name_by_id accepts.

    Returns:
        The model's modelPath string, or None when the id is not registered.
    """
    entry = read_existing_models(path).get(model_id)
    return entry.modelPath if entry is not None else None
def get_network_name_by_id(model_id: int, path: str = "models.json") -> str:
    """
    Return the network name ("networkName") registered for ``model_id``.

    Args:
        model_id: Id of the model to look up.
        path: Model registry file to read.

    Returns:
        The networkName string, or None when the id is not in the registry.
    """
    entry = read_existing_models(path).get(model_id)
    if entry is None:
        return None
    return entry.networkName
@app.route("/modalities", methods=["GET"])
def get_modalities():
    """
    Return the sorted, de-duplicated input modalities of all registered models.

    On failure, the traceback is written to the error log and the client
    receives {"error": <message>} with HTTP 500.
    """
    try:
        registry = read_existing_models()
        print(registry)
        unique_modalities = {entry.model.inputModality for entry in registry.values()}
        return jsonify(sorted(unique_modalities))
    except Exception as exc:
        write_error_log("Error in /modalities endpoint:\n\n" + traceback.format_exc())
        return jsonify({"error": str(exc)}), 500
@app.route("/regions", methods=["GET"])
def get_regions():
    """
    Return the sorted list of distinct regions covered by the registered models.

    On failure, the traceback is written to the error log and the client
    receives {"error": <message>} with HTTP 500.
    """
    try:
        registry = read_existing_models()
        distinct_regions = {entry.model.region for entry in registry.values()}
        return jsonify(sorted(distinct_regions))
    except Exception as exc:
        write_error_log("Error in /regions endpoint:\n\n" + traceback.format_exc())
        return jsonify({"error": str(exc)}), 500
def kill_subprocess(process, timeout: float = 10.0):
    """
    Stop an inference subprocess (and its process group) and reap it.

    On Windows, a CTRL_BREAK_EVENT is sent to the process. If it has not
    exited within ``timeout`` seconds, it is force-killed. On other platforms,
    SIGKILL is sent to the process's whole process group.

    A process whose exit status has already been collected (returncode set)
    is never signalled. Its pid may already have been reused by an unrelated
    process, and os.getpgid raises ProcessLookupError for a reaped pid.

    Args:
        process (subprocess.Popen): The process to stop.
        timeout (float): Windows only. Seconds to wait for a graceful exit
            after CTRL_BREAK_EVENT before falling back to ``process.kill()``.
    """
    if process.returncode is not None:
        print("[Server] Process already exited; nothing to kill")
        return

    if platform.system() == "Windows":
        print("[Server] Sending CTRL_BREAK_EVENT to Windows process")
        process.send_signal(signal.CTRL_BREAK_EVENT)
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # The child ignored the break event; terminate it forcibly.
            process.kill()
            process.wait()
        return

    print("[Server] Killing process group on Unix")
    try:
        os.killpg(os.getpgid(process.pid), signal.SIGKILL)
    except ProcessLookupError:
        # The process vanished between the returncode check and the signal.
        pass
    process.wait()
@app.route("/getmodels", methods=["POST"])
def get_models():
    """
    Return the models compatible with the requested modality and region.

    Expects a JSON object body {"modality": ..., "region": ...}. Responds with:
    - a JSON list of AIModel dicts whose inputModality and region both match
      (server-only fields such as the checkpoint path are not included);
    - 400 when the body is missing, is not a JSON object, or lacks either field;
    - 500 on any other failure, with the traceback written to the error log.
    """
    print("Received POST request to /getmodels")
    try:
        # silent=True returns None for a missing, malformed or non-JSON body
        # instead of raising. Such requests then get the 400 below rather
        # than being turned into a 500 by the broad except.
        data = request.get_json(silent=True)
        print("Raw JSON received:", data)
        if not data or not isinstance(data, dict):
            print("No JSON data received.")
            return jsonify({"error": "Missing JSON data"}), 400

        modality = data.get("modality")
        region = data.get("region")
        print(f"Filtering models for modality='{modality}', region='{region}'")
        if not modality or not region:
            print("Missing modality or region in the request.")
            return jsonify({"error": "Missing modality or region"}), 400

        fetched_models = read_existing_models()
        filtered_models = [
            asdict(wrapper.model)
            for wrapper in fetched_models.values()
            if wrapper.model.inputModality == modality and wrapper.model.region == region
        ]
        print(f"Found {len(filtered_models)} matching models.")
        for model in filtered_models:
            print("Matched model:", model)
        return jsonify(filtered_models)
    except Exception as e:
        write_error_log("Error in /getmodels endpoint:\n\n" + traceback.format_exc())
        print("Exception occurred while handling /getmodels:", str(e))
        return jsonify({"error": str(e)}), 500
def read_progress(job_id, process):
    """
    Relay an inference subprocess's output and record its progress updates.

    Runs on the background thread started by /process. Each step:
    - Every non-empty stdout line is echoed to the server console.
    - A line containing "::PROGRESS::" must carry a JSON payload after the
      marker, which replaces progress_state[job_id]. Unparseable payloads are
      written to the error log.
    - When the output stream closes, the job is removed from running_processes
      (unless /cancel already removed it) and the process is reaped. Finished
      jobs therefore neither stay cancellable nor linger as zombie processes.
    - If the output contained a Python traceback, it is written to the error
      log and the job's progress entry is marked with "error": True.

    Args:
        job_id (str): Identifier of the job whose progress is tracked.
        process (subprocess.Popen): Process started with stdout=PIPE and
            text=True (as /process does).
    """
    marker = "::PROGRESS::"
    output_lines = []
    # Closing the pipe when done releases the file descriptor.
    with process.stdout:
        for raw_line in process.stdout:
            line = raw_line.strip()
            if not line:
                continue
            output_lines.append(line)
            if marker in line:
                try:
                    _, payload = line.split(marker, 1)
                    data = json.loads(payload.strip())
                    with progress_lock:
                        progress_state[job_id] = data
                except Exception:
                    write_error_log(traceback.format_exc())
            print(f"[{job_id} LOG]", line)

    # Deregister before reaping, so /cancel never signals a pid whose exit
    # status has already been collected (and which may have been reused).
    with lock:
        if running_processes.get(job_id) is process:
            del running_processes[job_id]
    process.wait()

    # Log the traceback, if present.
    tb_header = "Traceback (most recent call last):"
    full_output = "\n".join(output_lines)
    if tb_header in full_output:
        traceback_part = full_output.split(tb_header, 1)[-1]
        write_error_log(f"[{job_id}] Traceback:\n{tb_header}{traceback_part}")
        with progress_lock:
            state = progress_state.get(job_id)
            # The last progress payload might be missing or not a dict;
            # don't let that crash this thread.
            if not isinstance(state, dict):
                state = {}
                progress_state[job_id] = state
            state["error"] = True
            state["status"] = traceback_part  # Remove this and just send "model failed" if you dont want to specify filepaths etc
@app.route("/progress/<job_id>")
def get_progress(job_id):
    """
    Return the latest progress dict recorded for ``job_id``.

    The dict is whatever read_progress last parsed from the inference
    subprocess's stdout. Responds 204 with an empty body while nothing (or an
    empty value) has been recorded.
    """
    with progress_lock:
        snapshot = progress_state.get(job_id)
        if snapshot:
            return jsonify(snapshot)
        return '', 204
@app.route("/download/<job_id>", methods=["GET"])
def download_output(job_id):
    """
    Send the generated NIfTI for ``job_id`` to the client as an attachment.

    The inference output must have been saved as {OUTPUT_DIR}/{job_id}.nii.gz.
    Its voxel data is multiplied by 20 (presumably undoing a normalisation
    applied for inference — TODO confirm the factor). The result is written
    to {OUTPUT_DIR}/denormalized/{job_id}.nii.gz and sent with the unchanged
    file name {job_id}.nii.gz. The job's uploaded input file is deleted only
    after the output has been found.

    Returns:
        The rescaled file, or ("Output file not ready yet.", 404) when the
        output does not exist yet.
    """
    nifti_output = os.path.join(OUTPUT_DIR, f"{job_id}.nii.gz")
    # Check first. nib.load on a missing file raises, which would turn the
    # intended 404 into a 500. Deleting the upload before this check would
    # also remove the input of a job that is still running.
    if not os.path.exists(nifti_output):
        write_error_log(
            f"Error in /download endpoint:\n\nOutput file not found: {nifti_output}"
        )
        return "Output file not ready yet.", 404

    # Write the rescaled volume next to, not over, the raw output. Rescaling
    # in place multiplied the data by 20 again on every repeated download.
    # Keeping the same basename keeps the attachment name unchanged.
    download_dir = os.path.join(OUTPUT_DIR, "denormalized")
    os.makedirs(download_dir, exist_ok=True)
    download_path = os.path.join(download_dir, f"{job_id}.nii.gz")

    img = nib.load(nifti_output)
    denormalized_data = img.get_fdata() * 20
    denorm_img = nib.Nifti1Image(denormalized_data, img.affine, img.header)
    nib.save(denorm_img, download_path)

    remove_uploaded_nifti(job_id)
    return send_file(download_path, mimetype="application/octet-stream", as_attachment=True)
def remove_uploaded_nifti(job_id):
    """
    Delete the uploaded input file {UPLOAD_DIR}/{job_id}.nii.gz for a job.

    Called by /download after the output is found and by /cancel. A missing
    file is not an error; it is only reported on the console. The remove is
    attempted directly (EAFP), so a concurrent download and cancel for the
    same job cannot make one of them crash between an existence check and
    the removal.
    """
    path = os.path.join(UPLOAD_DIR, f"{job_id}.nii.gz")
    try:
        os.remove(path)
    except FileNotFoundError:
        print(f"Uploaded file not found: {path}")
    else:
        print(f"Deleted uploaded file: {path}")
@app.route("/cancel/<job_id>", methods=["POST"])
def cancel_job(job_id):
    """
    Cancel a running inference job.

    While holding ``lock``, kills the job's subprocess, removes it from
    running_processes and deletes its uploaded input file.

    Responds:
    - 200 {"status": "Cancelled"} on success;
    - 404 {"error": "No such job running"} when no process is registered
      for ``job_id``.
    """
    print(f"[Server] Trying to cancel job: '{job_id}'")
    print(f"[Server] Available running jobs: {list(running_processes.keys())}")
    with lock:
        target = running_processes.get(job_id)
        if not target:
            return jsonify({"error": "No such job running"}), 404
        kill_subprocess(target)
        del running_processes[job_id]
        remove_uploaded_nifti(job_id)
    return jsonify({"status": "Cancelled"}), 200
@app.route("/process", methods=["POST"])
def process_nifti():
    """
    Start model inference on an uploaded NIfTI file in a background subprocess.

    Expects a multipart form with:
    - a "file" part containing the input NIfTI;
    - a "metadata" field holding JSON with at least "title" (used as the job
      id), "region" and "model" (an object with an "id").

    The upload is saved as {UPLOAD_DIR}/{job_id}.nii.gz. TEST_SCRIPT is then
    launched with the selected checkpoint, and read_progress records the
    job's progress in progress_state[job_id].

    Responses:
        200 {"status": "Running model"} once the subprocess has started.
        400 for a missing file/metadata, malformed metadata, a title that is
            not a plain file name, or an unknown model id.
        409 when a job with the same id is still running.
        500 if the model registry cannot be read or the subprocess cannot be
            started (details go to the error log).
    """
    uploaded_file = request.files.get("file")
    metadata_json = request.form.get("metadata")
    if not uploaded_file or not metadata_json:
        return "Missing file or metadata", 400

    # Validate the metadata before touching the filesystem.
    try:
        metadata = json.loads(metadata_json)
        print("Received metadata:", metadata)
        # The job id is used as a dict key matched against URL path strings,
        # so normalise it to str.
        job_id = str(metadata['title'])
        region = metadata['region']
        model_id = metadata['model'].get("id")
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        return f"Invalid metadata: {e}", 400

    # The job id becomes a file name inside UPLOAD_DIR/OUTPUT_DIR; refuse
    # anything that could point outside those directories.
    if job_id in ("", ".", "..") or os.path.basename(job_id) != job_id:
        return "Invalid title: must be a plain file name", 400

    # Read the registry once, instead of once per looked-up field.
    try:
        model_entry = read_existing_models().get(model_id)
    except Exception as e:
        write_error_log("Error in /process endpoint:\n\n" + traceback.format_exc())
        return f"Error reading models: {e}", 500
    if model_entry is None:
        return f"Unknown model id: {model_id}", 400

    # A second concurrent run with the same id would overwrite the first
    # run's input, output and progress, and orphan its (uncancellable) process.
    with lock:
        existing = running_processes.get(job_id)
        if existing is not None:
            if existing.poll() is None:
                return f"Job '{job_id}' is already running", 409
            del running_processes[job_id]  # finished job; forget it

    # Save uploaded NIfTI file
    upload_path = os.path.join(UPLOAD_DIR, f"{job_id}.nii.gz")
    uploaded_file.save(upload_path)
    print(f"Saved uploaded file to: {upload_path}")

    model_folder_path = model_entry.modelPath
    which_epoch = model_entry.networkName
    print(f"Model path: {model_folder_path}")
    print("Network name:", which_epoch)
    switch_path = os.path.join(CHECKPOINTS_DIR, model_folder_path)

    # NOTE(review): "python" is resolved via PATH, not sys.executable —
    # presumably so this also works from a PyInstaller build; confirm.
    command = [
        "python", TEST_SCRIPT,
        "--gpu_ids", "-1",
        "--json_id", job_id,
        "--dataroot", DATA_PATH,
        "--test_district", region,
        "--which_epoch", str(which_epoch),
        "--out_path", OUTPUT_DIR,
        "--upload_dir", upload_path,
        "--checkpoints_dir", CHECKPOINTS_DIR,
        "--switch_paths", switch_path,
    ]
    print(f"Running model command:\n{command}")

    with progress_lock:
        progress_state[job_id] = {
            "step": 1,
            "total": 1,
            "job_id": job_id,
            "finished": False,
            "status": "Loading inference subprocess",
            "error": False,
        }

    is_windows = platform.system() == "Windows"
    try:
        process = subprocess.Popen(
            command,
            shell=False,
            stdout=subprocess.PIPE,  # capture stdout so read_progress can parse it
            stderr=subprocess.STDOUT,
            text=True,
            creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if is_windows else 0,
            # start_new_session runs setsid() in the child (it is ignored on
            # Windows). Unlike preexec_fn, it is safe in a threaded server. The
            # new process group is what kill_subprocess() signals.
            start_new_session=not is_windows,
        )
    except Exception as e:
        write_error_log("Error in /process endpoint:\n\n" + traceback.format_exc())
        print(f"Error running command: {e}")
        # Otherwise /progress would report "Loading" forever.
        with progress_lock:
            progress_state[job_id]["error"] = True
            progress_state[job_id]["status"] = f"Error starting model: {e}"
        return f"Error starting model: {e}", 500

    # Register before starting the reader thread, so the reader's
    # deregistration can never run before the registration.
    with lock:
        running_processes[job_id] = process
    threading.Thread(
        target=read_progress,
        args=(job_id, process),
        daemon=True,
    ).start()
    return jsonify({"status": "Running model"}), 200
# When run as a script, start Flask's built-in server on port 8000,
# listening on all network interfaces (0.0.0.0).
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)