Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"train_loss": None,
"val_metrics": None, # dict after each epoch (accuracy, precision, …)
"test_metrics": None, # dict after training finishes
"run_id": None,
"error": None,
}
_lock = threading.Lock()
Expand All @@ -35,7 +36,7 @@
def _reset_state():
_state.update(
status="idle", epoch=0, total_epochs=0,
train_loss=None, val_metrics=None, test_metrics=None, error=None,
train_loss=None, val_metrics=None, test_metrics=None, run_id=None, error=None,
)


Expand All @@ -62,7 +63,7 @@ def _train_worker(dataset_root, backbone, epochs, batch_size, max_samples, out_d
_state["status"] = "training"
_state["total_epochs"] = epochs

test_metrics = run_train(
test_metrics, run_id = run_train(
dataset_root=dataset_root,
backbone=backbone,
epochs=epochs,
Expand All @@ -75,6 +76,7 @@ def _train_worker(dataset_root, backbone, epochs, batch_size, max_samples, out_d

with _lock:
_state["status"] = "complete"
_state["run_id"] = run_id
if test_metrics is not None:
_state["test_metrics"] = {
k: v for k, v in test_metrics.items()
Expand All @@ -87,9 +89,11 @@ def _train_worker(dataset_root, backbone, epochs, batch_size, max_samples, out_d
test_metrics["confusion_matrix"].tolist()
)
except Exception:
tb = traceback.format_exc()
print(f"\n[Training ERROR]\n{tb}")
with _lock:
_state["status"] = "error"
_state["error"] = traceback.format_exc()
_state["error"] = tb


# ── Routes ────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -144,6 +148,7 @@ def api_status():
"train_loss": _state["train_loss"],
"val_metrics": _state["val_metrics"],
"test_metrics": _state["test_metrics"],
"run_id": _state["run_id"],
"error": _state["error"],
}
return jsonify(payload)
Expand Down
19 changes: 13 additions & 6 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
import argparse
import json
from datetime import datetime

import numpy as np
import torch
Expand Down Expand Up @@ -134,6 +135,9 @@ def train(
max_samples: int = 0,
progress_callback=None,
):
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"Run ID: {run_id}")

device = _get_device()
print(f"Using device: {device}")

Expand Down Expand Up @@ -262,7 +266,7 @@ def train(

if val_auc > best_val_auc:
best_val_auc = val_auc
ckpt_path = Path(out_dir) / f"best_{backbone}.pt"
ckpt_path = Path(out_dir) / f"best_{backbone}_{run_id}.pt"
torch.save(
{
"model_state": model.state_dict(),
Expand All @@ -281,7 +285,7 @@ def train(
print("\n--- Post-training evaluation on TEST splits ---")

# Reload best checkpoint if it exists
ckpt_path = Path(out_dir) / f"best_{backbone}.pt"
ckpt_path = Path(out_dir) / f"best_{backbone}_{run_id}.pt"
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state"])
Expand All @@ -299,7 +303,7 @@ def train(
_print_metrics(test_metrics, header="TEST SET RESULTS")

# Save metrics to JSON
metrics_path = Path(out_dir) / "test_metrics.json"
metrics_path = Path(out_dir) / f"test_metrics_{run_id}.json"
serialisable = {
k: v for k, v in test_metrics.items()
if k not in ("confusion_matrix", "y_true", "y_scores")
Expand All @@ -319,16 +323,19 @@ def train(
plot_roc_curve,
)

plot_training_history(train_losses, val_losses, val_aucs)
plot_training_history(train_losses, val_losses, val_aucs, tag=run_id)

if test_metrics is not None:
plot_confusion_matrix(
test_metrics["confusion_matrix"],
class_names=["Fake", "Real"],
tag=run_id,
)
plot_roc_curve(test_metrics["y_true"], test_metrics["y_scores"])
plot_roc_curve(test_metrics["y_true"], test_metrics["y_scores"], tag=run_id)

return test_metrics
if test_metrics is not None:
test_metrics["run_id"] = run_id
return test_metrics, run_id


# ---------------------------------------------------------------------------
Expand Down
12 changes: 9 additions & 3 deletions src/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def plot_training_history(
val_losses: Optional[List[float]] = None,
val_aucs: Optional[List[float]] = None,
save: bool = True,
tag: Optional[str] = None,
) -> None:
"""Plot training / validation loss and (optionally) validation AUC curves."""
ncols = 2 if val_aucs else 1
Expand Down Expand Up @@ -59,7 +60,8 @@ def plot_training_history(

if save:
_ensure_figures_dir()
path = os.path.join(FIGURES_DIR, "training_history.png")
fname = f"training_history_{tag}.png" if tag else "training_history.png"
path = os.path.join(FIGURES_DIR, fname)
plt.savefig(path, dpi=150)
print(f"Saved training history plot to {path}")
plt.close(fig)
Expand All @@ -73,6 +75,7 @@ def plot_confusion_matrix(
cm: np.ndarray,
class_names: Optional[List[str]] = None,
save: bool = True,
tag: Optional[str] = None,
) -> None:
"""Plot a 2x2 confusion matrix heatmap."""
if class_names is None:
Expand Down Expand Up @@ -105,7 +108,8 @@ def plot_confusion_matrix(

if save:
_ensure_figures_dir()
path = os.path.join(FIGURES_DIR, "confusion_matrix.png")
fname = f"confusion_matrix_{tag}.png" if tag else "confusion_matrix.png"
path = os.path.join(FIGURES_DIR, fname)
plt.savefig(path, dpi=150)
print(f"Saved confusion matrix plot to {path}")
plt.close(fig)
Expand All @@ -119,6 +123,7 @@ def plot_roc_curve(
y_true: np.ndarray,
y_scores: np.ndarray,
save: bool = True,
tag: Optional[str] = None,
) -> None:
"""Plot ROC curve with AUC annotation."""
fpr, tpr, _ = roc_curve(y_true, y_scores)
Expand All @@ -137,7 +142,8 @@ def plot_roc_curve(

if save:
_ensure_figures_dir()
path = os.path.join(FIGURES_DIR, "roc_curve.png")
fname = f"roc_curve_{tag}.png" if tag else "roc_curve.png"
path = os.path.join(FIGURES_DIR, fname)
plt.savefig(path, dpi=150)
print(f"Saved ROC curve plot to {path}")
plt.close(fig)
19 changes: 11 additions & 8 deletions templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<title>Face the Truth — AI Face Detector</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css"
rel="stylesheet"
integrity="sha384-QWTKZyjpPEjISv5WaRU9OFeRpok6YcnS/hYkl6PFlCuTBqLDO6FSDMO3TMggMHYb"
integrity="sha384-QWTKZyjpPEjISv5WaRU9OFeRpok6YctnYmDr5pNlyT2bRjXh0JMhjY6hW+ALEwIH"
crossorigin="anonymous">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
Expand Down Expand Up @@ -451,7 +451,7 @@ <h6 class="card-title mb-2" style="font-size:.85rem">ROC Curve</h6>
clearInterval(pollTimer);
showStatus('success', '\u2705', 'Training complete!');
trainBtn.disabled = false;
renderResults(data.test_metrics);
renderResults(data.test_metrics, data.run_id);
}
else if (data.status === 'error') {
clearInterval(pollTimer);
Expand All @@ -461,9 +461,16 @@ <h6 class="card-title mb-2" style="font-size:.85rem">ROC Curve</h6>
} catch (e) { /* network hiccup */ }
}

function renderResults(m) {
function renderResults(m, runId) {
document.getElementById('results-section').style.display = 'block';

const ts = Date.now();
const tag = runId || ts;
document.getElementById('plotHistory').src = `/figures/training_history_${tag}.png?t=${ts}`;
document.getElementById('plotCM').src = `/figures/confusion_matrix_${tag}.png?t=${ts}`;
document.getElementById('plotROC').src = `/figures/roc_curve_${tag}.png?t=${ts}`;

if (!m) return;
document.getElementById('results-section').style.display = '';

const metrics = [
{ name: 'Accuracy', val: m.accuracy, desc: 'Overall percentage of correct predictions (real + fake).' },
Expand Down Expand Up @@ -505,10 +512,6 @@ <h6 class="card-title mb-2" style="font-size:.85rem">ROC Curve</h6>
document.getElementById('cmTable').innerHTML = html;
}

const ts = Date.now();
document.getElementById('plotHistory').src = `/figures/training_history.png?t=${ts}`;
document.getElementById('plotCM').src = `/figures/confusion_matrix.png?t=${ts}`;
document.getElementById('plotROC').src = `/figures/roc_curve.png?t=${ts}`;
}
</script>
</body>
Expand Down