-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_vector_optimized.py
More file actions
113 lines (97 loc) · 3.87 KB
/
create_vector_optimized.py
File metadata and controls
113 lines (97 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import json
import re
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
from prompts import COMPARISON_SYSTEM_PROMPT
import steering_opt
import argparse
load_dotenv()

# CLI configuration for the steering-vector optimization run.
parser = argparse.ArgumentParser(description="Optimize steering vector with awareness setting.")
parser.add_argument('--setting', type=str, default="unaware", choices=["aware", "unaware"],
                    help="Prompting setting: aware or unaware.")
# Parse the layer as an int at the CLI boundary instead of deferring the
# str->int conversion to the use site; downstream `int(args.layer)` still
# works unchanged on an int, so this is backward compatible.
parser.add_argument('--layer', type=int, default=14, help="Hidden layer to steer.")
parser.add_argument('--lr', type=float, default=0.1, help="Learning rate for vector optimization.")
parser.add_argument('--max_iters', type=int, default=20, help="Maximum optimization iterations.")
parser.add_argument('--affine_rank', type=int, default=-1,
                    help="Rank of the affine steering map; <= 0 disables the affine component.")
args = parser.parse_args()
# Normalize the sentinel: any non-positive rank means "no affine transform".
args.affine_rank = args.affine_rank if args.affine_rank > 0 else None

# Short identifier for the model being steered.
TARGET = "llama3.1-8b-instruct"
def load_jsonl(path):
    """Read a JSON-Lines file and return its records as a list of parsed objects."""
    records = []
    with open(path, "r") as handle:
        for raw_line in handle:
            records.append(json.loads(raw_line))
    return records
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# NOTE(review): the env var is read as "HFTOKEN"; confirm it is not meant to
# be the conventional "HF_TOKEN" — a misnamed variable silently yields None.
HF_TOKEN = os.getenv("HFTOKEN")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    token=HF_TOKEN
)

# Llama 3.1 ships without a pad token, so fall back to EOS for padding.
# Fixes from the original:
#  - pad_token_id is set inside the guard, so an existing, distinct pad
#    token is never clobbered with EOS unconditionally;
#  - the resize_token_embeddings(len(tokenizer)) call is dropped: aliasing
#    pad to EOS adds no new tokens, so the resize was a no-op.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()
# Remove chat template logic, use plain prompts

def flip_output(answer):
    """Return the opposite binary choice: "1" <-> "2"."""
    # def instead of a lambda assignment (PEP 8 E731); behavior is identical.
    return "1" if answer == "2" else "2"

# Use correct path for bias, lsp examples based on setting.  Reuse the
# module's own load_jsonl helper instead of repeating the same read loop
# three times.
bias_examples = load_jsonl(os.path.join("steering_inputs", args.setting, "bias_examples.jsonl"))
agreement_examples = load_jsonl(os.path.join("steering_inputs", args.setting, "agreement_examples.jsonl"))
lsp_examples = load_jsonl(os.path.join("steering_inputs", args.setting, "lsp_examples.jsonl"))

# Only the bias examples are currently used for optimization.
examples = bias_examples  # + agreement_examples + lsp_examples
def chat_template(prompt, post_script="<|start_header_id|>assistant<|end_header_id|>\n\n"):
    """Wrap a user prompt in the model's chat template.

    Renders the comparison system prompt plus the user message through the
    tokenizer's chat template (untokenized), then appends ``post_script``
    (by default the assistant header) so generation continues as the
    assistant turn.
    """
    messages = [
        {"role": "system", "content": COMPARISON_SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(messages, tokenize=False)
    return rendered + post_script
# Optimization process
print(f"Loaded {len(examples)} examples for setting: {args.setting}")
# Build one training datapoint per example: steer the model away from the
# flipped (opposite) answer and toward the recorded unbiased answer.
datapoints = []
for example in tqdm(examples):
    text = chat_template(example['prompt'])
    # src = the completion to suppress (opposite of the unbiased output),
    # dst = the completion to promote.
    src_completion = flip_output(example['unbiased_output'])
    dst_completion = example['unbiased_output']
    datapoints.append(
        steering_opt.TrainingDatapoint(
            text,
            src_completions=[src_completion],
            dst_completions=[dst_completion]
        )
    )
# Clamp the requested layer to the model's depth.
# NOTE(review): hidden-layer indices are presumably 0..num_hidden_layers-1,
# so min(..., num_hidden_layers) can still produce an out-of-range index —
# confirm against steering_opt's layer-indexing convention.
layer = min(int(args.layer), model.config.num_hidden_layers)
vec_data = steering_opt.optimize_completion(
    model, datapoints, layer, tokenizer=tokenizer,
    lr=args.lr, max_iters=args.max_iters, use_transformer_lens=False,
    do_target_loss_avg=False, affine_rank=args.affine_rank, return_loss=True,
    target_loss=None, target_loss_target_iters=5,
    debug=True
)
print("Steering vector optimized.")
# Save the vector to workspace-relative path
# NOTE(review): the filename hard-codes "affine_r1" even though --affine_rank
# is configurable (and defaults to None, i.e. no affine component at all) —
# consider deriving the suffix from args.affine_rank so the file name matches
# what was actually optimized; kept as-is here to avoid breaking downstream
# consumers of this exact path.
out_path = os.path.join("vectors", "optimization", f"steering_vector_{args.setting}_affine_r1.pkl")
os.makedirs(os.path.dirname(out_path), exist_ok=True)
import pickle  # local import kept byte-identical; consider hoisting to the top-of-file imports
with open(out_path, "wb") as f:
    pickle.dump(vec_data, f)
print(f"Optimized steering vector saved to {out_path}")