diff --git a/API/Routes/Case/CaseRoute.py b/API/Routes/Case/CaseRoute.py index 186f467d..a3684cd6 100644 --- a/API/Routes/Case/CaseRoute.py +++ b/API/Routes/Case/CaseRoute.py @@ -449,6 +449,8 @@ def prepareCSV(): def downloadCSV(): try: casename = session.get('osycase', None) + if casename is None: + return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400 dataFile = Path(Config.DATA_STORAGE,casename,'export.csv') dir = Path(Config.DATA_STORAGE,casename) diff --git a/API/Routes/DataFile/DataFileRoute.py b/API/Routes/DataFile/DataFileRoute.py index 7dbbe881..56b04efd 100644 --- a/API/Routes/DataFile/DataFileRoute.py +++ b/API/Routes/DataFile/DataFileRoute.py @@ -210,7 +210,11 @@ def downloadDataFile(): # return jsonify(response), 200 #path = "/Examples.pdf" case = session.get('osycase', None) + if case is None: + return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400 caserunname = request.args.get('caserunname') + if not caserunname: + return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400 Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or '')) dataFile = Path(Config.DATA_STORAGE,case, 'res',caserunname, 'data.txt') return send_file(dataFile.resolve(), as_attachment=True, max_age=0) @@ -224,7 +228,11 @@ def downloadDataFile(): def downloadFile(): try: case = session.get('osycase', None) + if case is None: + return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400 file = request.args.get('file') + if not file: + return jsonify({'message': 'Missing required parameter: file.', 'status_code': 'error'}), 400 Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', 'csv', file or '')) dataFile = Path(Config.DATA_STORAGE,case,'res','csv',file) return send_file(dataFile.resolve(), as_attachment=True, max_age=0) @@ -238,8 +246,14 @@ def downloadFile(): def downloadCSVFile(): try: case = session.get('osycase', None) + if case is None: + return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400 file = request.args.get('file') caserunname = request.args.get('caserunname') + if not file: + return jsonify({'message': 'Missing required parameter: file.', 'status_code': 'error'}), 400 + if not caserunname: + return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400 Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or '', 'csv', file or '')) dataFile = Path(Config.DATA_STORAGE,case,'res',caserunname,'csv',file) return send_file(dataFile.resolve(), as_attachment=True, max_age=0) @@ -253,7 +267,11 @@ def downloadCSVFile(): def downloadResultsFile(): try: case = session.get('osycase', None) + if case is None: + return jsonify({'message': 'No active session. Please select a model first.', 'status_code': 'error'}), 400 caserunname = request.args.get('caserunname') + if not caserunname: + return jsonify({'message': 'Missing required parameter: caserunname.', 'status_code': 'error'}), 400 Config.validate_path(Config.DATA_STORAGE, os.path.join(case or '', 'res', caserunname or '')) dataFile = Path(Config.DATA_STORAGE,case, 'res', caserunname,'results.txt') return send_file(dataFile.resolve(), as_attachment=True, max_age=0) diff --git a/tests/test_app_smoke.py b/tests/test_app_smoke.py index a77696bf..3c2750fb 100644 --- a/tests/test_app_smoke.py +++ b/tests/test_app_smoke.py @@ -122,5 +122,61 @@ def test_repo_has_no_unmerged_paths(self): self.assertEqual(result.stdout.strip(), "") +class DownloadRouteGuardTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + sys.path.insert(0, str(API_DIR)) + os.environ.setdefault("MUIOGO_SECRET_KEY", "smoke-test-secret") + cls.app_module = importlib.import_module("app") + + def setUp(self): + self.client = self.app_module.app.test_client() + + def test_download_routes_require_active_session(self): + endpoints = [ + ("/downloadDataFile", {"caserunname": "run1"}), + ("/downloadFile", {"file": "result.csv"}), + ("/downloadCSVFile", {"file": "result.csv", "caserunname": "run1"}), + ("/downloadResultsFile", {"caserunname": "run1"}), + ("/downloadCSV", {}), + ] + + for path, query in endpoints: + with self.subTest(path=path): + response = self.client.get(path, query_string=query) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.get_json(), + { + "message": "No active session. Please select a model first.", + "status_code": "error", + }, + ) + + def test_download_routes_require_query_params(self): + with self.client.session_transaction() as session_data: + session_data["osycase"] = "demo" + + cases = [ + ("/downloadDataFile", {}, "Missing required parameter: caserunname."), + ("/downloadFile", {}, "Missing required parameter: file."), + ("/downloadCSVFile", {"caserunname": "run1"}, "Missing required parameter: file."), + ("/downloadCSVFile", {"file": "result.csv"}, "Missing required parameter: caserunname."), + ("/downloadResultsFile", {}, "Missing required parameter: caserunname."), + ] + + for path, query, message in cases: + with self.subTest(path=path, query=query): + response = self.client.get(path, query_string=query) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.get_json(), + { + "message": message, + "status_code": "error", + }, + ) + + if __name__ == "__main__": unittest.main()