"""
Example: Knowledge injection via LoRA SFT on a GPTQ-quantized model
Demonstrates how to teach a quantized model new knowledge using
LoRA SFT post-processing. The model learns about "OneCompression"
(a topic it has never seen) and can answer questions about it
after training.
Flow:
1. Quantize TinyLlama with GPTQ 4-bit (groupsize=128)
2. Build quantized model via create_quantized_model
3. Generate text BEFORE LoRA SFT (model does not know OneCompression)
4. Run LoRA SFT with OneCompression knowledge data
5. Generate text AFTER LoRA SFT (model can describe OneCompression)
6. Compare results side by side
Copyright 2025-2026 Fujitsu Ltd.
Author: Keiji Kimura
Usage:
python example/post_process/example_lora_sft_knowledge.py
"""
from pathlib import Path

import torch

from onecomp import GPTQ, ModelConfig, Runner, PostProcessLoraSFT, setup_logger

setup_logger()

MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
KNOWLEDGE_DATA = str(Path(__file__).parent / "onecomp_knowledge.jsonl")
PROMPT = "Q: What is OneCompression?\nA:"
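
# NOTE (assumption): the schema of onecomp_knowledge.jsonl is not shown in this
# file. SFT knowledge data is typically one JSON object per line, for example:
#   {"text": "Q: What is OneCompression?\nA: OneCompression is ..."}
# The exact field names expected by PostProcessLoraSFT may differ.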

def generate_text(model, tokenizer, prompt, device, max_new_tokens=128):
    """Generate text from a prompt using greedy decoding."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding keeps the before/after comparison deterministic.
            # temperature is ignored when do_sample=False, so it is not passed.
            do_sample=False,
            repetition_penalty=1.2,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# ================================================================
# Step 1: Quantize the model with GPTQ 4-bit
# ================================================================
print("=" * 70)
print("Step 1: Quantizing TinyLlama with GPTQ 4-bit (groupsize=128)")
print("=" * 70)

model_config = ModelConfig(model_id=MODEL_ID, device="cuda:0")
gptq = GPTQ(wbits=4, groupsize=128)
runner = Runner(
    model_config=model_config,
    quantizer=gptq,
    max_length=512,
    num_calibration_samples=128,
)
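# run() executes the quantization pass: 128 calibration sequences of up to 512
# tokens flow through the model so each group of 128 weight columns can be
# quantized to 4 bits while minimizing layer-wise reconstruction error, which
# is the standard GPTQ procedure. (Assumption: the calibration dataset is
# supplied internally by Runner, since none is configured here.)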
runner.run()

# ================================================================
# Step 2: Build quantized model
# ================================================================
print("\n" + "=" * 70)
print("Step 2: Building quantized model via create_quantized_model")
print("=" * 70)

model, tokenizer = runner.create_quantized_model(
    pack_weights=False,
    use_gemlite=False,
)

# ================================================================
# Step 3: Generate BEFORE LoRA SFT
# ================================================================
print("\n" + "=" * 70)
print("Step 3: Generating text BEFORE LoRA SFT")
print("=" * 70)

model.to("cuda:0")
before_text = generate_text(model, tokenizer, PROMPT, "cuda:0")
# Move the model off the GPU and empty the CUDA cache to free VRAM before training.
model.to("cpu")
torch.cuda.empty_cache()

print(f"\nPrompt: {PROMPT}")
print(f"Response:\n{before_text}")

# ================================================================
# Step 4: Run LoRA SFT with OneCompression knowledge
# ================================================================
print("\n" + "=" * 70)
print("Step 4: Running LoRA SFT with OneCompression knowledge data")
print("=" * 70)

post_process = PostProcessLoraSFT(
    data_files=KNOWLEDGE_DATA,
    max_length=256,
    epochs=50,
    batch_size=2,
    gradient_accumulation_steps=1,
    lr=3e-4,
    lora_r=16,
    lora_alpha=32,
    logging_steps=5,
)
post_process.run(model, model_config)

# ================================================================
# Step 5: Generate AFTER LoRA SFT
# ================================================================
print("\n" + "=" * 70)
print("Step 5: Generating text AFTER LoRA SFT")
print("=" * 70)

model.to("cuda:0")
after_text = generate_text(model, tokenizer, PROMPT, "cuda:0")
model.to("cpu")
torch.cuda.empty_cache()

print(f"\nPrompt: {PROMPT}")
print(f"Response:\n{after_text}")

# ================================================================
# Step 6: Compare results
# ================================================================
print("\n" + "=" * 70)
print("Comparison: Before vs After LoRA SFT")
print("=" * 70)

print(f"\nPrompt: {PROMPT}")
print("\n--- BEFORE LoRA SFT ---")
print(before_text)
print("\n--- AFTER LoRA SFT ---")
print(after_text)
print("=" * 70)