Skip to content

Commit 5f8f89a

Browse files
authored
Merge pull request #416 from tejassinghbhati/fix-empty-session-crash-415
fix: gracefully handle missing active session on download routes (fixes #415)
2 parents 4f77c85 + cdba5d1 commit 5f8f89a

3 files changed

Lines changed: 76 additions & 0 deletions

File tree

API/Routes/Case/CaseRoute.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ def prepareCSV():
449449
def downloadCSV():
450450
try:
451451
casename = session.get('osycase', None)
452+
if casename is None:
453+
return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400
452454
dataFile = Path(Config.DATA_STORAGE,casename,'export.csv')
453455

454456
dir = Path(Config.DATA_STORAGE,casename)

API/Routes/DataFile/DataFileRoute.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,11 @@ def downloadDataFile():
210210
# return jsonify(response), 200
211211
#path = "/Examples.pdf"
212212
case = session.get('osycase', None)
213+
if case is None:
214+
return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400
213215
caserunname = request.args.get('caserunname')
216+
if not caserunname:
217+
return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400
214218
Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or ''))
215219
dataFile = Path(Config.DATA_STORAGE,case, 'res',caserunname, 'data.txt')
216220
return send_file(dataFile.resolve(), as_attachment=True, max_age=0)
@@ -224,7 +228,11 @@ def downloadDataFile():
224228
def downloadFile():
225229
try:
226230
case = session.get('osycase', None)
231+
if case is None:
232+
return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400
227233
file = request.args.get('file')
234+
if not file:
235+
return jsonify({'message': 'Missing required parameter: file.', 'status_code': 'error'}), 400
228236
Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', 'csv', file or ''))
229237
dataFile = Path(Config.DATA_STORAGE,case,'res','csv',file)
230238
return send_file(dataFile.resolve(), as_attachment=True, max_age=0)
@@ -238,8 +246,14 @@ def downloadFile():
238246
def downloadCSVFile():
239247
try:
240248
case = session.get('osycase', None)
249+
if case is None:
250+
return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400
241251
file = request.args.get('file')
242252
caserunname = request.args.get('caserunname')
253+
if not file:
254+
return jsonify({'message': 'Missing required parameter: file.', 'status_code': 'error'}), 400
255+
if not caserunname:
256+
return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400
243257
Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or '', 'csv', file or ''))
244258
dataFile = Path(Config.DATA_STORAGE,case,'res',caserunname,'csv',file)
245259
return send_file(dataFile.resolve(), as_attachment=True, max_age=0)
@@ -253,7 +267,11 @@ def downloadCSVFile():
253267
def downloadResultsFile():
254268
try:
255269
case = session.get('osycase', None)
270+
if case is None:
271+
return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400
256272
caserunname = request.args.get('caserunname')
273+
if not caserunname:
274+
return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400
257275
Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or ''))
258276
dataFile = Path(Config.DATA_STORAGE,case, 'res', caserunname,'results.txt')
259277
return send_file(dataFile.resolve(), as_attachment=True, max_age=0)

tests/test_app_smoke.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,61 @@ def test_repo_has_no_unmerged_paths(self):
122122
self.assertEqual(result.stdout.strip(), "")
123123

124124

125+
class DownloadRouteGuardTests(unittest.TestCase):
126+
@classmethod
127+
def setUpClass(cls):
128+
sys.path.insert(0, str(API_DIR))
129+
os.environ.setdefault("MUIOGO_SECRET_KEY", "smoke-test-secret")
130+
cls.app_module = importlib.import_module("app")
131+
132+
def setUp(self):
133+
self.client = self.app_module.app.test_client()
134+
135+
def test_download_routes_require_active_session(self):
136+
endpoints = [
137+
("/downloadDataFile", {"caserunname": "run1"}),
138+
("/downloadFile", {"file": "result.csv"}),
139+
("/downloadCSVFile", {"file": "result.csv", "caserunname": "run1"}),
140+
("/downloadResultsFile", {"caserunname": "run1"}),
141+
("/downloadCSV", {}),
142+
]
143+
144+
for path, query in endpoints:
145+
with self.subTest(path=path):
146+
response = self.client.get(path, query_string=query)
147+
self.assertEqual(response.status_code, 400)
148+
self.assertEqual(
149+
response.get_json(),
150+
{
151+
"message": "No active session. Please select a model first.",
152+
"status_code": "error",
153+
},
154+
)
155+
156+
def test_download_routes_require_query_params(self):
157+
with self.client.session_transaction() as session_data:
158+
session_data["osycase"] = "demo"
159+
160+
cases = [
161+
("/downloadDataFile", {}, "Missing required parameter: caserunname."),
162+
("/downloadFile", {}, "Missing required parameter: file."),
163+
("/downloadCSVFile", {"caserunname": "run1"}, "Missing required parameter: file."),
164+
("/downloadCSVFile", {"file": "result.csv"}, "Missing required parameter: caserunname."),
165+
("/downloadResultsFile", {}, "Missing required parameter: caserunname."),
166+
]
167+
168+
for path, query, message in cases:
169+
with self.subTest(path=path, query=query):
170+
response = self.client.get(path, query_string=query)
171+
self.assertEqual(response.status_code, 400)
172+
self.assertEqual(
173+
response.get_json(),
174+
{
175+
"message": message,
176+
"status_code": "error",
177+
},
178+
)
179+
180+
125181
if __name__ == "__main__":
126182
unittest.main()

0 commit comments

Comments
 (0)