🛡️ ControlNET: A Firewall for RAG-based LLM Systems

This repository contains the official implementation of "CONTROLNET: A Firewall for RAG-based LLM System". ControlNET is a comprehensive defense framework designed to detect and mitigate malicious activities and vulnerabilities within Retrieval-Augmented Generation (RAG) pipelines.

By monitoring the internal activations of Large Language Models (LLMs), ControlNET trains lightweight linear detectors to identify adversarial intent and applies harmlessness alignment via LoRA to keep RAG interactions safe.
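To make the detect-then-mitigate idea concrete, here is a minimal sketch of the routing decision a firewall of this kind makes for each query. It assumes a detector that maps one activation vector to a probability that the query is malicious. The function `firewall_route`, the 0.5 threshold, the path labels, and `toy_detector` are all hypothetical and are not taken from this repository.

```python
from typing import Callable, Sequence

# Hypothetical type: one hidden-state vector taken from a single LLM layer.
Activation = Sequence[float]


def firewall_route(
    activation: Activation,
    detector: Callable[[Activation], float],
    threshold: float = 0.5,
) -> str:
    """Return which generation path a query should take (hypothetical labels).

    `detector` is assumed to return P(malicious); scores at or above
    `threshold` are routed to the LoRA-aligned (harmless) path.
    """
    score = detector(activation)
    return "harmless_lora" if score >= threshold else "base_model"


def toy_detector(activation: Activation) -> float:
    # Illustration only: flags activations whose first coordinate is large.
    return 1.0 if activation[0] > 0.8 else 0.0


if __name__ == "__main__":
    print(firewall_route([0.9, 0.1], toy_detector))  # harmless_lora
    print(firewall_route([0.2, 0.4], toy_detector))  # base_model
```

Injecting the detector as a callable keeps the routing logic independent of how the detector was trained.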

🔬 Supported Threat Models (Risk Types)

ControlNET evaluates and defends against five primary security risks in RAG systems:

  • Reconnaissance: Probing the system for structural or systemic vulnerabilities.
  • Exfiltration: Attempting to extract sensitive context or system prompts.
  • Unauthorized Access: Bypassing access controls to retrieve restricted domain data (e.g., financial, employee, or case data).
  • Knowledge Poisoning: Injecting malicious context to alter the LLM's factual generation.
  • Hijacking: Forcing the model to execute adversarial instructions overriding the user's prompt.

🗂️ Evaluated Datasets

  • FinQA: Financial question-answering dataset.
  • HotpotQA: Multi-hop reasoning QA dataset.
  • MS MARCO: Microsoft MAchine Reading COmprehension dataset (passage ranking and question answering).
  • MedicalRAG: Medical retrieval-augmented generation dataset.

🏗️ System Architecture

ControlNet/
├── activations/           # Activation extraction and FAISS indexing
│   └── generate.py        # Main script for generating model hidden states
├── config/                # Centralized configurations
│   └── models.py          # Model initialization and dataset mapping
├── models/                # Base architecture definitions
│   └── model.py           # Core model wrapper classes
├── ragsys/                # RAG system utilities
│   └── load.py            # Corpus loading and mean pooling utilities
├── training/              # Defense training pipelines
│   ├── dataset.py         # DataLoader for activation states
│   ├── train_linear_model.py # Linear probing for anomaly detection
│   └── Lora.py            # Harmlessness implementation via Low-Rank Adaptation
├── utils/                 # Utilities and processors
│   └── activations.py     # Risk-specific activation processors (Base, Reconnaissance, Hijacking, etc.)
└── CONFIG.py              # Global risk type configuration

⚙️ Installation and Environment Setup

We recommend using Conda to manage your environment to ensure reproducibility.

# Clone the repository
git clone https://github.com/ZJUICSR/ControlNET.git
cd ControlNET

# Create and activate the environment
conda env create -f environment.yml
conda activate controlnet

Key Dependencies:

  • torch >= 1.12.0
  • transformers >= 4.20.0
  • faiss-cpu (or faiss-gpu) >= 1.7.2
  • scikit-learn >= 1.1.0
  • sentence-transformers
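If you manage dependencies outside Conda, a quick sanity check against the minimum versions listed above can save a failed run. The sketch below is a hypothetical helper, not part of the repository. It assumes Python 3.10+ (for `importlib.metadata` name normalization), compares only the numeric release part of each version string, and checks `faiss-cpu`, so `faiss-gpu` users should swap that distribution name.

```python
import re
from importlib import metadata


def parse_version(text: str) -> tuple[int, ...]:
    """Turn '1.13.1+cu117' or '4.20.0' into a comparable tuple (release part only)."""
    release = re.match(r"\d+(?:\.\d+)*", text)
    if release is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in release.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = parse_version(installed), parse_version(minimum)
    # Pad with zeros so '1.12' compares equal to '1.12.0'.
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


# Minimum versions from the list above, keyed by PyPI distribution name.
MINIMUMS = {
    "torch": "1.12.0",
    "transformers": "4.20.0",
    "faiss-cpu": "1.7.2",
    "scikit-learn": "1.1.0",
}


def check_environment() -> dict[str, str]:
    report = {}
    for dist, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            report[dist] = "missing"
            continue
        report[dist] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    print(check_environment())
```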

🚀 Pipeline Usage

After configuration, the ControlNET defense pipeline runs in three phases: Activation Extraction, Detector Training, and Mitigation (Harmlessness Implementation). Because configuration is numbered Phase 1 below, these correspond to Phases 2 to 4.

Phase 1: Configuration

First, set the target threat model you wish to analyze. Edit ControlNet/CONFIG.py to set the current_risk:

# ControlNet/CONFIG.py
sub_risk = [
    'Reconnaissance',
    'Exfiltration', 
    'Unauthorized_Access',
    'Knowledge',
    'Hijacking'
]

# Set the active risk type (e.g., index 0 for Reconnaissance)
current_risk = sub_risk[0] 
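Selecting the risk by list index is easy to get wrong when the list is reordered. A minimal alternative, assuming you are free to edit CONFIG.py, is to select by name and fail loudly on a typo. The helper `select_risk` is hypothetical; the key strings are copied from the `sub_risk` list above (note that Knowledge Poisoning uses the key `'Knowledge'`).

```python
sub_risk = [
    "Reconnaissance",
    "Exfiltration",
    "Unauthorized_Access",
    "Knowledge",  # config key used for Knowledge Poisoning
    "Hijacking",
]


def select_risk(name: str) -> str:
    """Hypothetical helper: return `name` if it is a valid risk key, else raise."""
    if name not in sub_risk:
        raise ValueError(f"unknown risk {name!r}; expected one of {sub_risk}")
    return name


# Equivalent to sub_risk[0], but robust to reordering of the list.
current_risk = select_risk("Reconnaissance")
```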

Phase 2: Activation Generation & Indexing

Run the generation script to build the FAISS vector database (corpus embeddings from bert-base-uncased) and extract the LLM's hidden-state activations for both clean and adversarial inputs.

The script selects which subsets to process from the active risk (e.g., clean vs. poisoned inputs, or access tiers such as employee/financial data for Unauthorized Access).

cd ControlNet/activations
python generate.py \
    --model_name llama3_8b \
    --model_path /path/to/local/model_weights
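The repository lists mean pooling in `ragsys/load.py`, and retrieval runs through FAISS. As a dependency-free illustration of those two steps, the sketch below mean-pools token embeddings under an attention mask, L2-normalizes them, and performs exact inner-product search, which is what a flat inner-product FAISS index computes. Whether `generate.py` normalizes embeddings or uses an inner-product rather than L2 index is an assumption here, and the toy 2-d vectors stand in for real BERT outputs.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def mean_pool(token_embeddings: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> Vector:
    """Average the token vectors whose mask entry is 1 (padding tokens are ignored)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, x in enumerate(vec):
                total[i] += x
    if count == 0:
        raise ValueError("attention_mask selects no tokens")
    return [x / count for x in total]


def l2_normalize(vec: Sequence[float]) -> Vector:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def search(index: Sequence[Vector], query: Vector, k: int) -> List[Tuple[float, int]]:
    """Exact inner-product search: brute-force equivalent of a flat index."""
    scores = [(sum(a * b for a, b in zip(doc, query)), i) for i, doc in enumerate(index)]
    scores.sort(key=lambda pair: pair[0], reverse=True)
    return scores[:k]


if __name__ == "__main__":
    # Toy "token embeddings"; the second token of doc_b is padding and is ignored.
    doc_a = l2_normalize(mean_pool([[1.0, 0.0], [0.8, 0.2]], [1, 1]))
    doc_b = l2_normalize(mean_pool([[0.0, 1.0], [9.9, 9.9]], [1, 0]))
    query = l2_normalize(mean_pool([[0.9, 0.1]], [1]))
    print(search([doc_a, doc_b], query, k=1))  # doc index 0 ranks first
```

With unit-length vectors, inner product equals cosine similarity, which is why normalization is usually paired with an inner-product index.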

Phase 3: Train the Anomaly Detector

Train a linear classifier on the extracted activations to detect whether an incoming query exhibits malicious intent.

cd ../training
python train_linear_model.py
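A linear probe is simply logistic regression on activation vectors. The stdlib-only sketch below fits one with full-batch gradient descent on binary cross-entropy, assuming each query is represented by a fixed-length activation vector labeled 1 (malicious) or 0 (clean). It illustrates the technique and is not the contents of `train_linear_model.py`. The learning rate, epoch count, and toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_probe(
    X: Sequence[Sequence[float]],
    y: Sequence[int],
    lr: float = 0.5,
    epochs: int = 1000,
) -> Tuple[List[float], float]:
    """Fit weights w and bias b minimising mean binary cross-entropy."""
    n, dim = len(X), len(X[0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for xi, yi in zip(X, y):
            # d(loss)/d(logit) for BCE with a sigmoid output is (p - y).
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
            for j in range(dim):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - lr * gj / n for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict(w: Sequence[float], b: float, x: Sequence[float]) -> int:
    """Return 1 if the probe scores x as malicious (P >= 0.5), else 0."""
    return int(sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b) >= 0.5)
```

In practice, scikit-learn's `LogisticRegression` (already a listed dependency) performs the same fit with a proper solver and regularization.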

Phase 4: Harmlessness Implementation

Once malicious queries are detected, apply the mitigation strategy using LoRA to steer the model towards harmless refusal or safe generation.

python Lora.py
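LoRA keeps a pretrained weight matrix W frozen and learns a low-rank update, so a layer computes y = W x + (alpha / r) B A x, with B initialized to zero so training starts exactly from the base model. The pure-Python sketch below shows that forward pass for a single linear layer. Which layers `Lora.py` adapts, and its rank and alpha values, are not stated here, so `r`, `alpha`, and the class name `LoRALinear` are hypothetical.

```python
import random
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(M: Sequence[Sequence[float]], x: Sequence[float]) -> List[float]:
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x); W is frozen, only A and B would be trained."""

    def __init__(self, W: Matrix, r: int, alpha: float, seed: int = 0):
        d_out, d_in = len(W), len(W[0])
        rng = random.Random(seed)
        self.W = W  # frozen pretrained weight, d_out x d_in
        # A: r x d_in, small random init; B: d_out x r, zero init so the update starts at 0.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d_out)]
        self.scale = alpha / r

    def forward(self, x: Sequence[float]) -> List[float]:
        base = matvec(self.W, x)
        update = matvec(self.B, matvec(self.A, x))
        return [bv + self.scale * u for bv, u in zip(base, update)]

    def trainable_params(self) -> int:
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])


if __name__ == "__main__":
    layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]], r=1, alpha=2.0)
    print(layer.forward([1.0, 1.0]))  # [3.0, 7.0]: identical to W x because B is zero
```

The saving comes from the rank: for a 4096 x 4096 projection with r = 8, A and B together hold 65,536 trainable values versus 16,777,216 in W (about 0.39%).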

🧠 Supported Models

The current implementation provides out-of-the-box support for the following architectures:

  • LLaMA 3 (8B Instruct)
  • Mistral (7B)
  • Vicuna (7B)

To add a custom model, append its configuration to ControlNet/config/models.py.
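The structure of `config/models.py` is not documented in this README, so the following is only a hypothetical shape for such an entry: a small registry keyed by the value passed as `--model_name` (`llama3_8b` appears in the Phase 2 command). The `ModelSpec` fields, the `register_model`/`get_model` helpers, and the custom entry are assumptions. The paths are placeholders, and the hidden size of 4096 is LLaMA 3 8B's published width, which a linear detector would take as its input dimension.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class ModelSpec:
    name: str         # value passed as --model_name
    path: str         # local weights directory (placeholder)
    hidden_size: int  # width of the hidden-state vectors a detector consumes


MODEL_REGISTRY: Dict[str, ModelSpec] = {}


def register_model(spec: ModelSpec) -> None:
    if spec.name in MODEL_REGISTRY:
        raise ValueError(f"model {spec.name!r} is already registered")
    MODEL_REGISTRY[spec.name] = spec


def get_model(name: str) -> ModelSpec:
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        raise KeyError(f"unknown model {name!r}; registered: {sorted(MODEL_REGISTRY)}") from None


register_model(ModelSpec("llama3_8b", "/path/to/local/model_weights", 4096))
# Hypothetical custom entry: name, path, and width are illustrative only.
register_model(ModelSpec("my_custom_model", "/path/to/custom/weights", 2048))
```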


📖 Citation

@article{yao2025control,
  title={ControlNET: A Firewall for RAG-based LLM System},
  author={Hongwei Yao and Haoran Shi and Shuo Shao and Yidou Chen and Cong Wang and Zhan Qin},
  journal={arXiv preprint arXiv:2504.09593},
  year={2025}
}

About

Code for paper "CONTROLNET: A Firewall for RAG-based LLM System"