Skip to content

Commit 0ce4bed

Browse files
author
ansschh
committed
Initial commit of Shielded RecRL codebase with safety audits and result aggregation
0 parents  commit 0ce4bed

81 files changed

Lines changed: 4729 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/safety.yml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# CI safety gate: fails the build when any audit report violates the
# thresholds declared in docs/policy_thresholds.yaml.
name: safety_check
on: [push, pull_request]
jobs:
  audit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: install deps
        # pyyaml/pandas are not guaranteed on the runner image; the inline
        # scripts below import both.
        run: pip install pyyaml pandas
      - name: parse thresholds
        id: th
        run: |
          python - <<'PY'
          import yaml, json, os
          th = yaml.safe_load(open('docs/policy_thresholds.yaml'))
          # ::set-output is deprecated and disabled by GitHub; append to
          # $GITHUB_OUTPUT instead.
          with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
              fh.write('th=' + json.dumps(th) + '\n')
          PY
      - name: check reports
        run: |
          python - <<'PY'
          import json, glob, os, sys, yaml
          th = yaml.safe_load(open('docs/policy_thresholds.yaml'))
          for rep in glob.glob('docs/safety_report_*.json'):
              r = json.load(open(rep))
              assert r['tox']['mean'] < th['tox_mean_max'], rep + " tox mean"
              assert r['tox']['p95'] < th['tox_p95_max'], rep + " tox p95"
              assert abs(r['pop']['delta']) < th['gini_delta_max'], rep + " gini"
              if 'parity' in r and r['parity']:
                  gap = abs(r['parity']['male_rate'] - r['parity']['female_rate'])
                  assert gap < th['gender_gap_max'], rep + " gender gap"
              assert r['privacy'] < th['privacy_leak_max'], rep + " leakage"
          print("All safety checks passed")
          PY
      - name: verify main results
        run: |
          test -f docs/main_results.csv
          python - <<'PY'
          import pandas as pd
          df = pd.read_csv("docs/main_results.csv")
          assert df.shape == (3, 5)
          assert df['ctr_sim'].min() > 0, "CTR empty"
          print("Results table OK")
          PY

.gitignore

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Python
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
*.so
6+
.Python
7+
build/
8+
develop-eggs/
9+
dist/
10+
downloads/
11+
eggs/
12+
.eggs/
13+
lib/
14+
lib64/
15+
parts/
16+
sdist/
17+
var/
18+
wheels/
19+
*.egg-info/
20+
.installed.cfg
21+
*.egg
22+
23+
# Jupyter Notebook
24+
.ipynb_checkpoints
25+
26+
# Virtual Environment
27+
venv/
28+
env/
29+
ENV/
30+
31+
# Project specific
32+
checkpoints/
33+
logs/
34+
data/*.json
35+
data/*.csv
36+
data/*.zip
37+
data/*.gz
38+
39+
# W&B
40+
wandb/
41+
42+
# IDE
43+
.idea/
44+
.vscode/
45+
*.swp
46+
*.swo
47+
48+
# OS
49+
.DS_Store
50+
Thumbs.db

README.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Shielded RecRL
2+
3+
This repository contains the implementation of Shielded RecRL, a method for adding chat-style explanations to recommender systems without affecting the underlying ranking model.
4+
5+
## Project Overview
6+
7+
Shielded RecRL uses a two-tower architecture:
8+
- A frozen ranking model (collaborative filtering)
9+
- A trainable language model that generates explanations
10+
11+
The key innovation is the gradient projection technique that prevents the explanation model from affecting the ranking model's performance.
12+
13+
## Setup Instructions
14+
15+
### Local Setup (Any OS)
16+
17+
1. Clone this repository:
18+
```bash
19+
git clone https://github.com/your_username/shielded-recrl.git
20+
cd shielded-recrl
21+
```
22+
23+
2. Edit `setup_local.sh` to update your GitHub username, then run:
24+
```bash
25+
bash setup_local.sh
26+
```
27+
28+
### RunPod Setup (Remote GPU)
29+
30+
1. Launch a RunPod instance with:
31+
- Runtime: PyTorch 2.3 | Python 3.10 | CUDA 12.2
32+
- GPU: NVIDIA A100 80GB or 2× RTX 4090 24GB
33+
- Volume: ≥ 400GB
34+
35+
2. SSH into your RunPod instance:
36+
```bash
37+
ssh -p YOUR_PORT runpod@YOUR_POD_ID.connect.runpod.io
38+
```
39+
40+
3. Edit `setup_runpod.sh` to update your GitHub username, then run:
41+
```bash
42+
bash setup_runpod.sh
43+
```
44+
45+
4. Verify the setup:
46+
```bash
47+
python gpu_test.py
48+
```
49+
50+
## Project Structure
51+
52+
```
53+
├── code
54+
│ ├── dataset/ # Dataset preprocessing
55+
│ ├── ranker/ # SASRec implementation
56+
│ ├── explainer/ # LLM with LoRA
57+
│ ├── projection/ # Gradient projection
58+
│ ├── trainer/ # Shielded PPO
59+
│ └── eval/ # Evaluation metrics
60+
├── data # Datasets
61+
├── checkpoints # Model checkpoints
62+
├── logs # Training logs
63+
├── experiments # Experiment configurations
64+
├── docs # Documentation
65+
└── docker # Docker configuration
66+
```
67+
68+
## Workflow
69+
70+
1. Edit code on your local machine
71+
2. Commit and push changes to GitHub
72+
3. Pull changes on RunPod and execute experiments
73+
4. Results are logged to W&B and saved to the persistent volume
74+
75+
## License
76+
77+
[Add your license information here]

code/audit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Safety, bias, and toxicity audit package for Shielded RecRL."""

code/audit/bias.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pandas as pd, numpy as np, json, pathlib, torch
2+
3+
def gini(array):
    """Return the Gini coefficient of *array* (0 = perfect equality).

    A tiny epsilon is added to every element so an all-zero input does not
    divide by zero.
    """
    values = np.sort(np.asarray(array, dtype=float) + 1e-9)
    n = values.size
    ranks = np.arange(1, n + 1)
    # Standard closed form: sum((2i - n - 1) * x_i) / (n * sum(x)) over the
    # sorted values.
    weighted = np.sum((2 * ranks - n - 1) * values)
    return weighted / (n * values.sum())
9+
10+
def popularity_shift(ranker_ckpt, lora_ckpt):
    """Compare item-concentration (Gini) before vs. after LoRA fine-tuning.

    Loads the frozen ranker checkpoint and the LoRA diff checkpoint from
    disk and proxies "popularity" by, respectively, the argmax histogram of
    the item embedding matrix and the per-row L1 magnitude of the lm_head
    weight delta.

    Returns a dict with ``gini_base``, ``gini_new`` and their ``delta``.
    """
    # Baseline: histogram of item-embedding argmax columns.
    # NOTE(review): assumes the checkpoint contains an "item_emb.weight"
    # tensor — confirm against the ranker's state_dict layout.
    ranker_state = torch.load(ranker_ckpt, map_location='cpu')
    _, counts = np.unique(ranker_state["item_emb.weight"].argmax(1),
                          return_counts=True)
    before = gini(counts)

    # Post-training: L1 mass of each lm_head row in the LoRA diff.
    lora_state = torch.load(lora_ckpt, map_location='cpu')
    magnitudes = lora_state["base_model.model.lm_head.weight"].abs().sum(1)
    after = gini(magnitudes.numpy())

    return {
        "gini_base": float(before),
        "gini_new": float(after),
        "delta": float(after - before),
    }
20+
21+
def gender_parity(rec_file, user_gender_csv):
    """Share of recommendations received by each gender (MovieLens only).

    Joins a recommendations CSV (columns: user, item) against a user
    demographics CSV (columns: user, gender) and returns each gender's
    fraction of the total recommendation count.
    """
    recommendations = pd.read_csv(rec_file)      # expects columns: user, item
    demographics = pd.read_csv(user_gender_csv)  # expects columns: user, gender
    joined = recommendations.merge(demographics, on="user")
    per_gender = joined.groupby("gender").size()
    share = per_gender / per_gender.sum()
    # .get returns the default 0 when a gender label is absent from the data.
    return {"male_rate": share.get('M', 0), "female_rate": share.get('F', 0)}

code/audit/generate_pdf_summary.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python
2+
"""
3+
Generate a PDF summary of safety audit reports for Shielded RecRL.
4+
5+
This script reads all safety report JSON files and creates a PDF summary.
6+
7+
Usage:
8+
python generate_pdf_summary.py
9+
"""
10+
import pathlib
11+
import os
12+
import json
13+
import sys
14+
15+
def generate_summary_pdf():
    """Render docs/safety_report_*.json into docs/safety_summary.pdf.

    Reads every safety report JSON plus the optional policy-thresholds YAML
    from ``$PROJ/docs`` and draws a one-page-per-section summary PDF.

    Returns:
        True on success; False when reportlab is unavailable or no report
        files are found.
    """
    # Fix: `time` was previously imported only inside the __main__ guard, so
    # calling this function from another module raised NameError.
    import time

    try:
        # reportlab is an optional third-party dependency.
        from reportlab.lib.pagesizes import letter
        from reportlab.pdfgen import canvas
        from reportlab.lib import colors
        from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
        from reportlab.platypus import Paragraph, Spacer
    except ImportError:
        print("Error: reportlab package not installed. Install with: pip install reportlab")
        return False

    def _fmt(value, spec=".4f"):
        # Reports may omit a metric, leaving the 'N/A' fallback string;
        # formatting a string with a float spec raises, so fall back to str().
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    root = pathlib.Path(os.getenv("PROJ", "."))
    reports_path = root / "docs"
    output_file = reports_path / "safety_summary.pdf"

    # Collect all report files
    report_files = list(reports_path.glob("safety_report_*.json"))
    if not report_files:
        print("Error: No safety report files found in", reports_path)
        return False

    # Load policy thresholds (optional; the summary still renders without them)
    try:
        thresholds_file = reports_path / "policy_thresholds.yaml"
        if thresholds_file.exists():
            import yaml
            with open(thresholds_file) as f:
                thresholds = yaml.safe_load(f)
        else:
            thresholds = None
    except Exception as e:
        print(f"Warning: Could not load thresholds: {e}")
        thresholds = None

    # Create the PDF
    c = canvas.Canvas(str(output_file), pagesize=letter)
    width, height = letter

    # Title
    c.setFont("Helvetica-Bold", 16)
    c.drawString(50, height - 50, "Shielded RecRL: Safety Audit Summary")
    c.setFont("Helvetica", 10)
    c.drawString(50, height - 70, f"Generated on {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Header line
    c.line(50, height - 80, width - 50, height - 80)

    # Policy thresholds section if available
    y_pos = height - 100
    if thresholds:
        c.setFont("Helvetica-Bold", 12)
        c.drawString(50, y_pos, "Policy Thresholds:")
        y_pos -= 20
        c.setFont("Helvetica", 10)
        for key, value in thresholds.items():
            c.drawString(70, y_pos, f"{key}: {value}")
            y_pos -= 15
        y_pos -= 10

    # Report summaries
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y_pos, "Safety Reports:")
    y_pos -= 20

    for report_file in sorted(report_files):
        try:
            with open(report_file) as f:
                report = json.load(f)

            # Extract dataset name from filename
            dataset = report_file.stem.replace('safety_report_', '')

            # Dataset header
            c.setFont("Helvetica-Bold", 11)
            c.drawString(50, y_pos, f"Dataset: {dataset}")
            y_pos -= 20

            # Toxicity metrics
            c.setFont("Helvetica", 10)
            if 'tox' in report:
                tox_mean = report['tox'].get('mean', 'N/A')
                tox_p95 = report['tox'].get('p95', 'N/A')
                c.drawString(70, y_pos, f"Toxicity: mean={_fmt(tox_mean)}, p95={_fmt(tox_p95)}")
                y_pos -= 15

            # Popularity bias
            if 'pop' in report:
                gini_base = report['pop'].get('gini_base', 'N/A')
                gini_new = report['pop'].get('gini_new', 'N/A')
                delta = report['pop'].get('delta', 'N/A')
                c.drawString(70, y_pos, f"Gini: base={_fmt(gini_base)}, new={_fmt(gini_new)}, delta={_fmt(delta)}")
                y_pos -= 15

            # Gender parity (ml25m only)
            if 'parity' in report and report['parity']:
                male = report['parity'].get('male_rate', 0)
                female = report['parity'].get('female_rate', 0)
                gap = abs(male - female)
                c.drawString(70, y_pos, f"Gender: M={male:.4f}, F={female:.4f}, gap={gap:.4f}")
                y_pos -= 15

            # Privacy
            if 'privacy' in report:
                priv = report['privacy']
                c.drawString(70, y_pos, f"Privacy leakage rate: {_fmt(priv, '.6f')}")
                y_pos -= 25

            # Start a new page when the current one is nearly full
            if y_pos < 100:
                c.showPage()
                y_pos = height - 50
                c.setFont("Helvetica-Bold", 12)
                c.drawString(50, y_pos, "Safety Reports (continued):")
                y_pos -= 30

        except Exception as e:
            # Fix: "Helvetica-Italic" is not a base-14 PDF font and raised a
            # KeyError in reportlab; the oblique variant is the correct name.
            c.setFont("Helvetica-Oblique", 10)
            c.drawString(70, y_pos, f"Error processing {report_file.name}: {str(e)}")
            y_pos -= 20

    # Summary
    if y_pos < 150:
        c.showPage()
        y_pos = height - 50

    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y_pos, "Summary:")
    y_pos -= 20
    c.setFont("Helvetica", 10)
    c.drawString(70, y_pos, f"Total reports processed: {len(report_files)}")

    # Save the PDF
    c.save()
    print(f"PDF summary saved to {output_file}")
    return True
151+
152+
if __name__ == "__main__":
    # NOTE(review): the function body references `time`; importing it here
    # makes it a module global when the script is run directly.
    import time
    sys.exit(0 if generate_summary_pdf() else 1)

code/audit/privacy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pandas as pd, Levenshtein as lev
2+
3+
def leakage_rate(texts, user_ids):
    """Fraction of generated texts that nearly reproduce a user ID.

    A text counts as a leak when its Levenshtein distance to any user ID
    (stringified) is at most 2.

    Args:
        texts: iterable of generated explanation strings.
        user_ids: iterable of user identifiers (compared via ``str()``).

    Returns:
        float in [0, 1]. Returns 0.0 for empty ``texts`` (the original
        implementation raised ZeroDivisionError).
    """
    texts = list(texts)
    if not texts:
        return 0.0
    # Stringify IDs once instead of once per text.
    id_strings = [str(uid) for uid in user_ids]
    leaks = 0
    for text in texts:
        # NOTE(review): distance is taken against the whole text, so only
        # texts roughly as short as an ID can ever match — confirm intended.
        if any(lev.distance(id_str, text) <= 2 for id_str in id_strings):
            leaks += 1
    return leaks / len(texts)

0 commit comments

Comments
 (0)