diff --git a/src/app.py b/src/app.py new file mode 100644 index 00000000..45bcd240 --- /dev/null +++ b/src/app.py @@ -0,0 +1,338 @@ + +import subprocess +from flask import Flask, render_template, request, redirect, url_for, session, send_from_directory +import os +import sys +import uuid +from flask_session import Session +from visualize_h5 import h5_visualization_route +from visualize_h5 import get_h5_plot # Import the function +import shutil +import os + + + +# Flask App Setup +app = Flask(__name__) + + + + +SESSION_DIR = './flask_session' +if os.path.exists(SESSION_DIR): + shutil.rmtree(SESSION_DIR) + os.makedirs(SESSION_DIR) + +app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'dev_secret_key') +app.config['SESSION_TYPE'] = 'filesystem' +Session(app) + +app.register_blueprint(h5_visualization_route) + +UPLOAD_FOLDER = os.path.join('../regr_smlp/code/uploads') +os.makedirs(UPLOAD_FOLDER, exist_ok=True) +app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER + + + +def call_smlp_api( argument_list): + """ + mode: 'train', 'predict', 'verify', 'optimize', etc. + argument_list: e.g. ["-data", "my_dataset", "-resp", "y1,y2", ...] + """ + + cmd_string = " ".join(argument_list) + + + cwd_dir = os.path.abspath("../regr_smlp/code/") + try: + result = subprocess.run( + argument_list, + capture_output=True, + text=True, + cwd=cwd_dir + ) + full_output = f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}" + print("DEBUG: Command output ->", full_output) + return result.stdout.strip() if result.returncode == 0 else full_output + + except Exception as e: + return f"Error calling SMLP: {str(e)}" + + +# HOME +@app.route('/') +def home(): + return render_template('index.html') + +# TRAIN +@app.route('/train', methods=['GET', 'POST']) +def train(): + if request.method == 'POST': + data_file = request.files.get('data_file') + dataset_path = None + + if data_file and data_file.filename: + dataset_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{data_file.filename}") + data_file.save(dataset_path) + + def add_arg(flag, value): + if value is not None and value != "": + arguments.extend([flag, str(value)]) + + arguments = ['python3', os.path.abspath("../src/run_smlp.py")] + + if not dataset_path : + return "Error: Missing required dataset or spec file", 400 + + # Required arguments + + add_arg("-data", os.path.abspath(dataset_path)) + add_arg("-out_dir", request.form.get('out_dir_val', './')) + add_arg("-pref", request.form.get('pref_val', 'TestTrain')) + add_arg("-mode", "train") + add_arg("-model", request.form.get('model')) + add_arg("-dt_sklearn_max_depth", request.form.get('dt_sklearn_max_depth')) + add_arg("-mrmr_pred", request.form.get('mrmr_pred')) + add_arg("-resp", request.form.get('resp')) + add_arg("-feat", request.form.get('feat')) + add_arg("-save_model", request.form.get('save_model')) + add_arg("-model_name", request.form.get('model_name')) + add_arg("-scale_feat", request.form.get('scale_feat')) + add_arg("-scale_resp", request.form.get('scale_resp')) + add_arg("-dt_sklearn_max_depth", request.form.get('dt_sklearn_max_depth')) + add_arg("-train_split", request.form.get('train_split')) + add_arg("-seed", request.form.get('seed_val')) + add_arg("-plots", request.form.get('plots')) + + additional_command = request.form.get('additional_command') + if additional_command: + arguments.extend(additional_command.split()) + + # Debugging + print("DEBUG: Final SMLP Command ->", " ".join(arguments)) + + output = call_smlp_api(arguments) + session['output'] = output + return redirect(url_for('results')) + + return render_template('train.html') + + + +# PREDICT + +@app.route('/predict', methods=['GET', 'POST']) +def predict(): + if request.method == 'POST': + model_file = request.files.get('model_file') + new_data_file = request.files.get('new_data_file') + + model_path = None + newdata_path = None + + if model_file and model_file.filename: + model_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{model_file.filename}") + model_file.save(model_path) + + if new_data_file and new_data_file.filename: + newdata_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{new_data_file.filename}") + new_data_file.save(newdata_path) + + def add_arg(flag, value): + if value is not None and value != "": + arguments.extend([flag, str(value)]) + + arguments = ['python3', os.path.abspath("../src/run_smlp.py")] + + if not model_path or not newdata_path: + return "Error: Both model file and new data file are required", 400 + + # Required + add_arg("-mode", "predict") + + # Process paths (remove file extension) + add_arg("-model_name", os.path.abspath(model_path)) + add_arg("-new_data", os.path.abspath(newdata_path)) + + # Optional user inputs + add_arg("-out_dir", request.form.get('out_dir_val', './')) + add_arg("-pref", request.form.get('pref_val', 'PredictRun')) + add_arg("-log_time", request.form.get('log_time', 'f')) + add_arg("-plots", request.form.get('plots')) + add_arg("-save_model", request.form.get('save_model')) + add_arg("-model_name", request.form.get('model_name')) + add_arg("-seed", request.form.get('seed_val')) + + + additional_command = request.form.get('additional_command') + if additional_command: + arguments.extend(additional_command.split()) + + # Debug output + print("DEBUG: Final Predict SMLP Command ->", " ".join(arguments)) + + output = call_smlp_api(arguments) + session['output'] = output + return redirect(url_for('results')) + + return render_template('predict.html') + + + +def clear_old_plots(): + """Remove all previous plots before running a new exploration.""" + if os.path.exists(PLOT_SAVE_DIR): + for filename in os.listdir(PLOT_SAVE_DIR): + file_path = os.path.join(PLOT_SAVE_DIR, filename) + if os.path.isfile(file_path): + os.remove(file_path) + + +# EXPLORATION +@app.route('/explore', methods=['GET', 'POST']) +def explore(): + modes_list = ['certify', 'query', 'verify', 'synthesize', 'optimize', 'optsyn'] + + if request.method == 'POST': + clear_old_plots() + chosen_mode = request.form.get('explore_mode', '') + + if chosen_mode not in modes_list: + return "Error: Invalid mode selected", 400 + + data_file = request.files.get('data_file') + spec_file = request.files.get('spec_file') + + dataset_path = None + spec_path = None + + if data_file and data_file.filename: + dataset_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{data_file.filename}") + data_file.save(dataset_path) + + if spec_file and spec_file.filename: + spec_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{spec_file.filename}") + spec_file.save(spec_path) + + if not dataset_path or not spec_path: + return "Error: Missing required dataset or spec file", 400 + + arguments = ['python3', os.path.abspath("../src/run_smlp.py")] + + def add_arg(flag, value): + """Helper function to add arguments only if they are not empty.""" + if value is not None and value != "": + arguments.extend([flag, str(value)]) + + # Required arguments + add_arg("-data", os.path.abspath(dataset_path)) + add_arg("-out_dir", request.form.get('out_dir_val', './')) + add_arg("-pref", request.form.get('pref_val', 'Test113')) + add_arg("-mode", chosen_mode) + add_arg("-spec", os.path.abspath(spec_path)) + add_arg("-pareto", request.form.get('pareto')) + add_arg("-resp", request.form.get('resp_expr')) + add_arg("-feat", request.form.get('feat_expr')) + add_arg("-model", request.form.get('model_expr')) + add_arg("-dt_sklearn_max_depth", request.form.get('dt_sklearn_max_depth')) + add_arg("-mrmr_pred", request.form.get('mrmr_pred')) + add_arg("-epsilon", request.form.get('epsilon')) + add_arg("-delta_rel", request.form.get('delta_rel')) + add_arg("-save_model_config", request.form.get('save_model_config')) + add_arg("-plots", request.form.get('plots')) + add_arg("-log_time", request.form.get('log_time')) + add_arg("-seed", request.form.get('seed_val')) + add_arg("-objv_names", request.form.get('objv_names')) + add_arg("-objv_exprs", request.form.get('objv_exprs')) + + + additional_command = request.form.get('additional_command') + if additional_command: + arguments.extend(additional_command.split()) + + # checks if there should be a 3d plot + session['use_h5'] = any('nn_keras' in arg for arg in arguments) + session.modified = True + # Debugging + print("DEBUG: Final SMLP Command ->", " ".join(arguments)) + + output = call_smlp_api(arguments) + session['output'] = output + return redirect(url_for('results')) + + return render_template('exploration.html', modes=modes_list) + + +# DOE + +@app.route('/doe', methods=['GET', 'POST']) +def doe(): + if request.method == 'POST': + doe_spec_file = request.files.get('doe_spec_file') + spec_path = None + + if doe_spec_file and doe_spec_file.filename: + spec_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}_{doe_spec_file.filename}") + doe_spec_file.save(spec_path) + + def add_arg(flag, value): + if value is not None and value != "": + arguments.extend([flag, str(value)]) + + arguments = ['python3', os.path.abspath("../src/run_smlp.py")] + + if not spec_path: + return "Error: Missing DOE spec file", 400 + + # Required DOE mode + add_arg("-doe_spec", os.path.abspath(spec_path)) + add_arg("-out_dir", request.form.get('out_dir_val', './')) + add_arg("-pref", request.form.get('pref_val', 'TestDOE')) + add_arg("-mode", "doe") + add_arg("-doe_algo", request.form.get('doe_algo')) + add_arg("-log_time", request.form.get('log_time', 'f')) + + additional_command = request.form.get('additional_command') + if additional_command: + arguments.extend(additional_command.split()) + + # Debugging + print("DEBUG: Final DOE SMLP Command ->", " ".join(arguments)) + + output = call_smlp_api(arguments) + session['output'] = output + return redirect(url_for('results')) + + return render_template('doe.html') + + + +# results +PLOT_SAVE_DIR = os.path.abspath("../regr_smlp/code/images/results") +os.makedirs(PLOT_SAVE_DIR, exist_ok=True) # Ensure the directory is created + +@app.route('/results') +def results(): + output = session.get('output', 'No output available yet.') + print("\n--- RESULTS ---") + print(output) + + # Get list of available plot files + plots = [] + if os.path.exists(PLOT_SAVE_DIR): + plots = [f for f in os.listdir(PLOT_SAVE_DIR) if f.endswith(".png")] + + use_h5 = session.get('use_h5', False) + h5_plot_html = get_h5_plot() if use_h5 else None + + return render_template('results.html', output=output, plots=plots, h5_plot_html=h5_plot_html, use_h5=use_h5) + +# Route to serve images from the custom directory +@app.route('/plots/') +def serve_plot(filename): + return send_from_directory(PLOT_SAVE_DIR, filename) + +# MAIN +if __name__ == '__main__': + app.run(debug=False, use_reloader=False) diff --git a/src/readme_app.md b/src/readme_app.md new file mode 100644 index 00000000..11ca9cff --- /dev/null +++ b/src/readme_app.md @@ -0,0 +1,28 @@ +# About Flask + +This Flask app provides a front-end interface for the SMLP system. It allows you to: + +Train models on your own datasets. + +Predict using trained models on new data. + +Explore or validate model properties + +Perform DOE (Design of Experiments) to sample data in systematic ways. + + + + +## Requirements + + pip install flask flask_session + + + +## How to run + cd smlp/src/python3 app.py + + + +By default, Flask runs on http://127.0.0.1:5000 + diff --git a/src/smlp_py/smlp_plots.py b/src/smlp_py/smlp_plots.py index 4b356a80..34e0ab3e 100644 --- a/src/smlp_py/smlp_plots.py +++ b/src/smlp_py/smlp_plots.py @@ -56,16 +56,35 @@ def response_distribution_plot(out_dir, y, resp_names, interactive): sys.exit(1) +# STATIC_IMAGES_DIR = os.path.abspath(os.path.join(os.getcwd(), "static/images")) + +# os.makedirs(STATIC_IMAGES_DIR, exist_ok=True) + + +# def plot(name, interactive, out_prefix=None, **show_kws): +# #print('saved figure filename: ', 'out_prefix', out_prefix, 'name', name) +# #print('interactive', interactive); print('show_kws', show_kws) +# if out_prefix is not None: +# #print('Saving plot ' + out_prefix + '_' + name + '.png') +# plot_path = os.path.join(STATIC_IMAGES_DIR, f"{out_prefix}_{name}.png") +# plt.savefig(plot_path) +# # if interactive: +# # #print('HERE2', show_kws) +# # plt.show(**show_kws) +# # plt.clf() +PLOT_SAVE_DIR = os.path.abspath("../../regr_smlp/code/images") +# Ensure the directory exists +# os.makedirs(PLOT_SAVE_DIR , exist_ok=True) + def plot(name, interactive, out_prefix=None, **show_kws): - #print('saved figure filename: ', 'out_prefix', out_prefix, 'name', name) - #print('interactive', interactive); print('show_kws', show_kws) if out_prefix is not None: - #print('Saving plot ' + out_prefix + '_' + name + '.png') - plt.savefig(out_prefix + '_' + name + '.png') - if interactive: - #print('HERE2', show_kws) - plt.show(**show_kws) - plt.clf() + plot_path = os.path.join(PLOT_SAVE_DIR , f"{out_prefix}_{name}.png") + plt.savefig(plot_path) # Save in the correct location + + # if interactive: + # plt.show(**show_kws) + + # plt.clf() def plot_data_columns(data): diff --git a/src/static/css/style.css b/src/static/css/style.css new file mode 100644 index 00000000..1d3d824f --- /dev/null +++ b/src/static/css/style.css @@ -0,0 +1,41 @@ +body { + background-color: var(--bs-body-bg); + color: var(--bs-body-color); + margin-bottom: 50px; + min-height: 100vh; + transition: background-color 0.3s ease, color 0.3s ease; +} + +h2 { + margin-top: 1rem; + color: var(--bs-heading-color); +} + +pre { + white-space: pre-wrap; + background-color: var(--bs-secondary-bg); + color: var(--bs-body-color); + padding: 1rem; + border-radius: 0.375rem; + border: 1px solid var(--bs-border-color); +} + +.navbar-brand { + font-weight: bold; + text-transform: uppercase; +} + +.navbar[data-bs-theme="dark"] { + --bs-navbar-color: rgba(255, 255, 255, 0.75); + --bs-navbar-hover-color: rgba(255, 255, 255, 0.9); + --bs-navbar-disabled-color: rgba(255, 255, 255, 0.25); + --bs-navbar-active-color: #fff; + --bs-navbar-brand-color: #fff; + --bs-navbar-brand-hover-color: #fff; + background-color: var(--bs-dark); +} + +[data-bs-theme="dark"] .form-check-input { + background-color: var(--bs-secondary-bg); + border-color: var(--bs-border-color); +} \ No newline at end of file diff --git a/src/templates/base.html b/src/templates/base.html new file mode 100644 index 00000000..5e3ab30a --- /dev/null +++ b/src/templates/base.html @@ -0,0 +1,72 @@ + + + + + + SMLP Dashboard + + + + + + + +
+
+ {% block content %}{% endblock %} +
+
+ + + + + \ No newline at end of file diff --git a/src/templates/doe.html b/src/templates/doe.html new file mode 100644 index 00000000..a95b9f89 --- /dev/null +++ b/src/templates/doe.html @@ -0,0 +1,49 @@ +{% extends "base.html" %} +{% block content %} +

Design of Experiments (DOE)

+ +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+
+{% endblock %} diff --git a/src/templates/exploration.html b/src/templates/exploration.html new file mode 100644 index 00000000..988b350d --- /dev/null +++ b/src/templates/exploration.html @@ -0,0 +1,138 @@ +{% extends "base.html" %} +{% block content %} +

Explore SMLP

+
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ + + + + +
+
+ +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ + +
+ + +
+ +
+ +
+ + + + + +
+
+ +
+ +
+ + + + + +
+ + + +
+ +
+ +{% endblock %} diff --git a/src/templates/index.html b/src/templates/index.html new file mode 100644 index 00000000..1958bcc0 --- /dev/null +++ b/src/templates/index.html @@ -0,0 +1,14 @@ +{% extends "base.html" %} +{% block content %} +
+

Welcome to SMLP Dashboard

+

A powerful web interface for training, predicting, exploring, and DOE with SMLP.

+ +
+ Train + Predict + Explore + DOE +
+
+{% endblock %} \ No newline at end of file diff --git a/src/templates/index.html:Zone.Identifier b/src/templates/index.html:Zone.Identifier new file mode 100644 index 00000000..e69de29b diff --git a/src/templates/predict.html b/src/templates/predict.html new file mode 100644 index 00000000..2a88e017 --- /dev/null +++ b/src/templates/predict.html @@ -0,0 +1,71 @@ +{% extends "base.html" %} +{% block content %} +

Predict with a Trained Model

+ +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ +
+{% endblock %} diff --git a/src/templates/results.html b/src/templates/results.html new file mode 100644 index 00000000..88001476 --- /dev/null +++ b/src/templates/results.html @@ -0,0 +1,74 @@ +{% extends "base.html" %} +{% block content %} +
+
+ +
+
+
+

Results

+
+
+
+
{{ output }}
+
+
+
+
+ +
+
+
+

Generated Plots

+
+
+ {% if plots %} +
+
+ {% for plot in plots %} +
+ + Generated Plot + +
+ {% endfor %} +
+
+ {% else %} +

No plots available.

+ {% endif %} +
+
+
+
+ + {% if use_h5 %} +
+
+
+
+

Model Visualization

+
+
+ +
+ {{ h5_plot_html | safe }} +
+
+
+
+
+ {% endif %} + +
+ +
+ +
+{% endblock %} diff --git a/src/templates/results.html:Zone.Identifier b/src/templates/results.html:Zone.Identifier new file mode 100644 index 00000000..e69de29b diff --git a/src/templates/train.html b/src/templates/train.html new file mode 100644 index 00000000..c37fe72f --- /dev/null +++ b/src/templates/train.html @@ -0,0 +1,103 @@ +{% extends "base.html" %} +{% block content %} +

Train a Model

+
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ + + + + +
+
+ +
+ + +
+ + +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ +
+ + + + + +
+
+ +
+ + +
+ +
+ +
+
+{% endblock %} diff --git a/src/test_app.py b/src/test_app.py new file mode 100644 index 00000000..b449ec78 --- /dev/null +++ b/src/test_app.py @@ -0,0 +1,226 @@ +import os +import unittest +import tempfile +import subprocess +from unittest.mock import patch +from app import app, call_smlp_api + +class AppTestCase(unittest.TestCase): + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + self.upload_folder = tempfile.mkdtemp() + app.config['UPLOAD_FOLDER'] = self.upload_folder + + def tearDown(self): + for file in os.listdir(self.upload_folder): + os.remove(os.path.join(self.upload_folder, file)) + os.rmdir(self.upload_folder) + + def test_home(self): + response = self.client.get('/') + self.assertEqual(response.status_code, 200) + + def test_explore_get(self): + response = self.client.get('/explore') + self.assertEqual(response.status_code, 200) + + def test_explore_invalid_mode(self): + response = self.client.post('/explore', data={'explore_mode': 'invalid_mode'}) + self.assertEqual(response.status_code, 400) + self.assertIn(b"Error: Invalid mode selected", response.data) + + def test_explore_missing_files(self): + response = self.client.post('/explore', data={'explore_mode': 'certify'}) + self.assertEqual(response.status_code, 400) + self.assertIn(b"Error: Missing required dataset or spec file", response.data) + + def test_explore_full_form_submission(self): + data_file_path = os.path.join(self.upload_folder, "data.csv") + spec_file_path = os.path.join(self.upload_folder, "spec.txt") + + with open(data_file_path, 'w') as f: + f.write("sample,data,content") + with open(spec_file_path, 'w') as f: + f.write('''{ + "version": "1.2", + "spec": { + "targets": ["y1"], + "constraints": [] + } + }''') + with self.client.session_transaction() as session: + output = session.get('output', '') + self.assertNotIn('Exception', output) + self.assertNotIn('Traceback', output) + + with open(data_file_path, 'rb') as data_file, open(spec_file_path, 'rb') as spec_file: + response = self.client.post('/explore', data={ + 'explore_mode': 'optimize', + 'data_file': (data_file, 'data.csv'), + 'spec_file': (spec_file, 'spec.txt'), + 'out_dir_val': './results', + 'pref_val': 'Test113', + 'pareto': 't', + 'resp': 'y', + 'feat': 'x1,x2', + 'model_expr': 'dt', + 'dt_sklearn_max_depth': '5', + 'mrmr_pred': '5', + 'epsilon': '0.01', + 'delta_rel': '0.05', + 'save_model_config': 'yes', + 'plots': 'yes', + 'log_time': 'yes', + 'seed_val': '42', + 'objv_names': 'objv1,objv2', + 'objv_exprs': 'y1>7 and y2<3', + 'additional_command': '-seed 10' + }, content_type='multipart/form-data') + + self.assertEqual(response.status_code, 302) + self.assertTrue(response.headers['Location'].endswith('/results')) + + def test_train_missing_file(self): + response = self.client.post('/train', data={}) + self.assertEqual(response.status_code, 400) + self.assertIn(b"Error: Missing required dataset or spec file", response.data) + + def test_train_valid_submission(self): + data_file_path = os.path.join(self.upload_folder, "data.csv") + with open(data_file_path, 'w') as f: + f.write("sample,data,content") + + with open(data_file_path, 'rb') as data_file: + response = self.client.post('/train', data={ + 'data_file': (data_file, 'data.csv'), + 'out_dir_val': './results' + }, content_type='multipart/form-data') + + self.assertEqual(response.status_code, 302) + self.assertTrue(response.headers['Location'].endswith('/results')) + + def test_call_smlp_api(self): + result = call_smlp_api(['echo', 'hello']) + self.assertEqual(result.strip(), 'hello') + + def test_results_route(self): + plot_dir = os.path.abspath("../regr_smlp/code/images/results") + os.makedirs(plot_dir, exist_ok=True) + plot_path = os.path.join(plot_dir, "dummy_plot.png") + with open(plot_path, 'wb') as f: + f.write(b'\x89PNG\r\n\x1a\n') + + with self.client.session_transaction() as sess: + sess['output'] = 'Test output message' + sess['use_h5'] = False + + response = self.client.get('/results') + self.assertEqual(response.status_code, 200) + self.assertIn(b'Test output message', response.data) + self.assertIn(b'dummy_plot.png', response.data) + + os.remove(plot_path) + + def test_serve_plot_route(self): + plot_dir = os.path.abspath("../regr_smlp/code/images/results") + os.makedirs(plot_dir, exist_ok=True) + filename = "test_image.png" + filepath = os.path.join(plot_dir, filename) + + with open(filepath, 'wb') as f: + f.write(b'\x89PNG\r\n\x1a\n') + + with open(filepath, 'rb') as f: + response = self.client.get(f'/plots/{filename}') + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content_type, 'image/png') + + os.remove(filepath) + + @patch('app.get_h5_plot') + def test_results_route_with_h5(self, mock_get_h5_plot): + mock_get_h5_plot.return_value = "
Mock H5 Plot
" + + with self.client.session_transaction() as sess: + sess['output'] = 'Output with H5' + sess['use_h5'] = True + + response = self.client.get('/results') + self.assertEqual(response.status_code, 200) + self.assertIn(b'Output with H5', response.data) + self.assertIn(b'Mock H5 Plot', response.data) + + def test_clear_old_plots(self): + from app import clear_old_plots, PLOT_SAVE_DIR + os.makedirs(PLOT_SAVE_DIR, exist_ok=True) + dummy_plot = os.path.join(PLOT_SAVE_DIR, 'dummy.png') + with open(dummy_plot, 'wb') as f: + f.write(b'\x89PNG\r\n\x1a\n') + + self.assertTrue(os.path.exists(dummy_plot)) + clear_old_plots() + self.assertFalse(os.path.exists(dummy_plot)) + + def test_doe_missing_file(self): + response = self.client.post('/doe', data={}) + self.assertEqual(response.status_code, 400) + self.assertIn(b"Error: Missing DOE spec file", response.data) + + def test_doe_valid_submission(self): + doe_file_path = os.path.join(self.upload_folder, "doe_spec.json") + with open(doe_file_path, 'w') as f: + f.write('''{ + "version": "1.2", + "spec": { + "design": "lhs", + "samples": 10 + } + }''') + + with open(doe_file_path, 'rb') as doe_file: + response = self.client.post('/doe', data={ + 'doe_spec_file': (doe_file, 'doe_spec.json'), + 'out_dir_val': './results', + 'pref_val': 'TestDOE', + 'doe_algo': 'lhs', + 'log_time': 'yes', + 'additional_command': '-seed 999' + }, content_type='multipart/form-data') + + self.assertEqual(response.status_code, 302) + self.assertTrue(response.headers['Location'].endswith('/results')) + + def test_predict_missing_files(self): + response = self.client.post('/predict', data={}) + self.assertEqual(response.status_code, 400) + self.assertIn(b"Error: Both model file and new data file are required", response.data) + + def test_predict_valid_submission(self): + model_file_path = os.path.join(self.upload_folder, "model.h5") + new_data_file_path = os.path.join(self.upload_folder, "new_data.csv") + + with open(model_file_path, 'w') as f: + f.write("mock model content") + with open(new_data_file_path, 'w') as f: + f.write("mock,new,data") + + with open(model_file_path, 'rb') as model_file, open(new_data_file_path, 'rb') as new_data_file: + response = self.client.post('/predict', data={ + 'model_file': (model_file, 'model.h5'), + 'new_data_file': (new_data_file, 'new_data.csv'), + 'out_dir_val': './results', + 'pref_val': 'PredictRun', + 'log_time': 'yes', + 'plots': 'yes', + 'save_model': 't', + 'model_name': 'my_model', + 'seed_val': '123', + 'additional_command': '-seed 999' + }, content_type='multipart/form-data') + + self.assertEqual(response.status_code, 302) + self.assertTrue(response.headers['Location'].endswith('/results')) + +if __name__ == '__main__': + unittest.main() diff --git a/src/test_visual.py b/src/test_visual.py new file mode 100644 index 00000000..d1327085 --- /dev/null +++ b/src/test_visual.py @@ -0,0 +1,102 @@ +import os +import unittest +import tempfile +import numpy as np +import pandas as pd +from unittest.mock import patch, MagicMock + +from visualize_h5 import ( + get_latest_file, + get_latest_h5_file, + get_latest_training_csv, + get_feature_estimates, + load_model_and_generate_predictions, + get_h5_plot, + H5_DIRECTORY, + CSV_DIRECTORY +) + +class TestVisualizeH5(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + for f in os.listdir(self.test_dir): + os.remove(os.path.join(self.test_dir, f)) + os.rmdir(self.test_dir) + + def test_get_latest_file_returns_none_if_no_files(self): + result = get_latest_file(self.test_dir, '*.csv') + self.assertIsNone(result) + + def test_get_latest_file_returns_latest(self): + file1 = os.path.join(self.test_dir, 'file1.csv') + file2 = os.path.join(self.test_dir, 'file2.csv') + with open(file1, 'w') as f: + f.write('test1') + with open(file2, 'w') as f: + f.write('test2') + os.utime(file1, (1, 1)) + os.utime(file2, (2, 2)) + + result = get_latest_file(self.test_dir, '*.csv') + self.assertEqual(result, file2) + + @patch('visualize_h5.get_latest_file') + def test_get_latest_h5_file(self, mock_get): + mock_get.return_value = '/path/to/latest.h5' + self.assertEqual(get_latest_h5_file(), '/path/to/latest.h5') + + @patch('visualize_h5.get_latest_file') + def test_get_latest_training_csv(self, mock_get): + mock_get.return_value = '/path/to/latest.csv' + self.assertEqual(get_latest_training_csv(), '/path/to/latest.csv') + + @patch('visualize_h5.get_latest_training_csv') + def test_get_feature_estimates_returns_means(self, mock_csv): + df = pd.DataFrame({"y1": [1, 2, 3], "y2": [4, 5, 6]}) + temp_csv = os.path.join(self.test_dir, 'test.csv') + df.to_csv(temp_csv, index=False) + mock_csv.return_value = temp_csv + + y1, y2 = get_feature_estimates() + self.assertEqual(y1, 2.0) + self.assertEqual(y2, 5.0) + + @patch('visualize_h5.tf.keras.models.load_model') + @patch('visualize_h5.get_latest_h5_file') + @patch('visualize_h5.get_feature_estimates') + def test_load_model_and_generate_predictions(self, mock_est, mock_h5, mock_load): + mock_h5.return_value = '/mock/model.h5' + mock_est.return_value = (5.0, 5.0) + + dummy_model = MagicMock() + dummy_model.predict.return_value = np.random.rand(900, 1) + mock_load.return_value = dummy_model + + y1, y2, z = load_model_and_generate_predictions() + self.assertIsNotNone(y1) + self.assertIsNotNone(y2) + self.assertIsNotNone(z) + self.assertEqual(y1.shape, y2.shape) + self.assertEqual(y1.shape, z.shape) + + @patch('visualize_h5.load_model_and_generate_predictions') + def test_get_h5_plot_returns_html(self, mock_load): + y1 = y2 = np.linspace(0, 1, 10) + Y1, Y2 = np.meshgrid(y1, y2) + Z = np.random.rand(10, 10) + mock_load.return_value = (Y1, Y2, Z) + + html = get_h5_plot() + self.assertIn(' 1: + predictions = predictions[:, 0] + + + Z = predictions.reshape(Y1.shape) + return Y1, Y2, Z + except Exception as e: + print(f"❌ Error loading model or generating predictions: {str(e)}") + return None, None, None + +def get_h5_plot(): + """Generate an interactive 3D surface plot from the model predictions.""" + y1, y2, z = load_model_and_generate_predictions() + + if y1 is None or y2 is None or z is None: + return "

⚠ Error: No valid model output found.

" + + pio.renderers.default = 'browser' + + fig = go.Figure() + + fig.add_trace(go.Surface( + x=y1, y=y2, z=z + )) + + fig.update_layout( + title="Interactive 3D Model Output", + scene=dict(xaxis_title="y1", yaxis_title="y2", zaxis_title="Model Prediction (Z)"), + margin=dict(l=0, r=0, b=0, t=40) + ) + + return pio.to_html(fig, full_html=False)