Fine-tune DistilBERT for prompt injection detection and export to ONNX for production inference.
This pipeline trains a binary classifier on top of distilbert-base-uncased to detect prompt injection attacks. The model is exported to ONNX format with INT8 quantization for efficient inference in the Go-based PIF detection engine.
```
train.py ──▶ Fine-tuned Model ──▶ export_onnx.py ──▶ ONNX + INT8 ──▶ evaluate.py
(PyTorch)                         (Optimum)          (~65 MB)        (Benchmark)
```
- Python 3.10+
- CUDA GPU recommended (training works on CPU but takes longer)
```
cd ml/
pip install -r requirements.txt
python train.py
```

Training data: `deepset/prompt-injections` from the Hugging Face Hub.
Test data: PIF benchmark dataset (benchmarks/dataset/) — 210 samples (100 benign + 110 injection) curated from real-world attacks.
Hyperparameters:
| Parameter | Value |
|---|---|
| Base model | distilbert-base-uncased |
| Max sequence length | 256 tokens |
| Epochs | 5 |
| Batch size | 16 |
| Learning rate | 2e-5 |
| Weight decay | 0.01 |
| Metric for best model | F1 score |
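The table above maps roughly onto Hugging Face `TrainingArguments`. A minimal sketch of that configuration — the output paths and per-epoch evaluation/save strategy are assumptions here, not necessarily what `train.py` actually uses:

```python
from transformers import TrainingArguments

# Sketch of the hyperparameters from the table above.
# output_dir and the epoch-based strategies are assumptions.
args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)
```

The max sequence length (256 tokens) is applied at tokenization time rather than in `TrainingArguments`.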
Output:
```
output/
├── best/                  # Best model checkpoint (PyTorch)
│   ├── model.safetensors
│   ├── config.json
│   ├── tokenizer.json
│   └── ...
└── results/
    ├── metrics.json       # Training & evaluation metrics
    └── logs/              # TensorBoard-compatible logs
```
After training, export the model to ONNX with INT8 quantization:
```
python export_onnx.py [--model-dir ./output/best] [--output-dir ./output/onnx]
```

This script:
- Exports the PyTorch model to ONNX format using Hugging Face Optimum
- Applies INT8 dynamic quantization (~50% size reduction)
- Validates the exported model against test cases
Output:
```
output/onnx/
├── model.onnx                 # Full ONNX model (~130 MB)
└── quantized/
    ├── model_quantized.onnx   # INT8 quantized (~65 MB)
    ├── config.json
    ├── tokenizer.json
    └── tokenizer_config.json
```
Run standalone evaluation against the PIF benchmark dataset:
```
python evaluate.py [--model-dir ./output/onnx/quantized]
```

Reports:
- Accuracy, F1, Precision, Recall
- Confusion matrix (TN, FP, FN, TP)
- Per-category detection rates (10 attack categories)
- False positive analysis
- Missed detection analysis
PIF targets: Detection rate ≥ 80%, False positive rate ≤ 10%.
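All of the headline metrics above can be derived from the confusion matrix alone. A self-contained sketch using a hypothetical confusion matrix (illustrative numbers, not actual results) over the 210-sample benchmark:

```python
def metrics(tn: int, fp: int, fn: int, tp: int) -> dict:
    """Compute the headline metrics from a binary confusion matrix
    (positive class = INJECTION)."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # detection rate on injection samples
    f1 = 2 * precision * recall / (precision + recall)
    fpr = fp / (fp + tn)             # false positive rate on benign samples
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1, "fpr": fpr}

# Hypothetical results for the 210-sample set (100 benign + 110 injection).
m = metrics(tn=95, fp=5, fn=15, tp=95)

# PIF targets: detection rate >= 0.80, false positive rate <= 0.10.
meets_targets = m["recall"] >= 0.80 and m["fpr"] <= 0.10
```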
After training and export, upload the quantized model:
```
huggingface-cli login
huggingface-cli upload ogulcanaydogan/pif-distilbert-injection-classifier output/onnx/quantized/
```

The Go ML detector downloads the model from the Hub at runtime:
```
# In PIF
pif scan --model ogulcanaydogan/pif-distilbert-injection-classifier "test prompt"
```

```
Input Text
     │
     ▼
WordPiece Tokenizer (max 256 tokens)
     │
     ▼
DistilBERT (6 layers, 768 hidden, 12 heads)
     │
     ▼
Classification Head (768 → 2)
     │
     ▼
Softmax → [BENIGN, INJECTION]
```
Label mapping:
- `0` → `BENIGN` — safe, legitimate prompts
- `1` → `INJECTION` — prompt injection attacks
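The final step — softmax over the two class logits, then the label map — can be sketched in plain Python (the logits below are illustrative, not real model output):

```python
import math

ID2LABEL = {0: "BENIGN", 1: "INJECTION"}

def classify(logits: list[float]) -> tuple[str, float]:
    """Softmax over the two class logits, then pick the winner.

    Returns (label, confidence)."""
    exps = [math.exp(x - max(logits)) for x in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = probs.index(max(probs))
    return ID2LABEL[idx], probs[idx]

# Hypothetical logits for an injection-looking prompt.
label, confidence = classify([-2.0, 3.5])
```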
The ONNX model is loaded by the Go `MLDetector` (behind the `ml` build tag) using ONNX Runtime. The detector maps model confidence to PIF severity levels:
| Confidence | Severity |
|---|---|
| ≥ 0.95 | Critical |
| ≥ 0.90 | High |
| ≥ 0.85 | Medium |
| ≥ 0.75 | Low |
| < 0.75 | Info (below threshold) |
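The table above is a simple threshold ladder. The real mapping lives in the Go `MLDetector`, so this Python version is purely illustrative:

```python
def severity(confidence: float) -> str:
    """Map INJECTION-class confidence to a PIF severity level,
    per the confidence/severity table above."""
    if confidence >= 0.95:
        return "Critical"
    if confidence >= 0.90:
        return "High"
    if confidence >= 0.85:
        return "Medium"
    if confidence >= 0.75:
        return "Low"
    return "Info"  # below detection threshold
```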
In the ensemble, the ML detector runs alongside the regex detector with configurable weights (default: regex 0.6, ML 0.4).
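Under the default weights this reduces to a weighted average of the two detector scores. A sketch of that combination — the Go engine's actual logic may differ in details such as how a missing detector score is handled:

```python
def ensemble_score(regex_score: float, ml_score: float,
                   w_regex: float = 0.6, w_ml: float = 0.4) -> float:
    """Weighted combination of regex and ML detector scores (both in [0, 1])."""
    return w_regex * regex_score + w_ml * ml_score

# A prompt the regex detector misses (0.0) but the ML model flags (0.95)
# still contributes 0.4 * 0.95 = 0.38 to the combined score.
score = ensemble_score(0.0, 0.95)
```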