-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathexample_lora_sft.py
More file actions
143 lines (117 loc) · 4.03 KB
/
example_lora_sft.py
File metadata and controls
143 lines (117 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Example: GPTQ quantization + LoRA SFT + save/load
End-to-end demonstration of the LoRA SFT post-process workflow:
1. Quantize TinyLlama with GPTQ 4-bit (groupsize=128)
2. Apply LoRA SFT post-process (WikiText-2)
3. Evaluate PPL (original vs quantized+LoRA)
4. Save the LoRA-applied model via save_quantized_model_pt()
5. Load the saved model via load_quantized_model_pt()
6. Generate text with the loaded model to verify it works
Copyright 2025-2026 Fujitsu Ltd.
Author: Keiji Kimura
Usage:
python example/post_process/example_lora_sft.py
"""
import torch
from onecomp import (
GPTQ,
ModelConfig,
PostProcessLoraSFT,
Runner,
load_quantized_model_pt,
setup_logger,
)
# Initialize logging via the project's helper (imported from onecomp)
# before any library calls run.
setup_logger()
# HF Hub id of the base checkpoint to quantize.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
# Output directory for the saved quantized + LoRA model (Steps 3/4).
SAVE_DIR = "./tinyllama_gptq4_lora"
# Prompt used by the generation smoke test in Step 5.
PROMPT = "Fujitsu is"
def generate_text(model, tokenizer, prompt, device, max_new_tokens=64):
    """Greedily generate a continuation of ``prompt`` with ``model``.

    Args:
        model: Causal LM exposing a HuggingFace-style ``generate()``.
        tokenizer: Matching tokenizer (callable; provides ``decode()``).
        prompt: Input text to continue.
        device: Device the model lives on; the encoded inputs are moved there.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded string (includes the prompt, special tokens stripped).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Inference only — no autograd graph needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding for reproducible output
            # NOTE: dropped `temperature=1.0` — it is ignored under greedy
            # decoding, and recent transformers releases warn when sampling
            # parameters are passed alongside do_sample=False.
            repetition_penalty=1.2,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Step 1: GPTQ quantization + LoRA SFT post-process ---
print("=" * 70)
print("Step 1: Quantize TinyLlama (GPTQ 4-bit) + LoRA SFT (WikiText-2)")
print("=" * 70)
# Place the base checkpoint on the first CUDA device.
config = ModelConfig(model_id=MODEL_ID, device="cuda:0")
# GPTQ quantizer: 4-bit weights with a quantization group size of 128.
quantizer = GPTQ(wbits=4, groupsize=128)
# LoRA SFT post-process trained on the WikiText-2 raw train split.
lora_sft = PostProcessLoraSFT(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    train_split="train",
    text_column="text",
    max_train_samples=256,
    max_length=512,
    epochs=2,
    batch_size=2,
    gradient_accumulation_steps=4,
    lr=1e-4,
    lora_r=16,
    lora_alpha=32,
    logging_steps=5,
)
# `runner` is reused by the later steps (PPL evaluation and saving).
runner = Runner(
    model_config=config,
    quantizer=quantizer,
    post_processes=[lora_sft],
    max_length=512,
    num_calibration_samples=128,
)
runner.run()
# --- Step 2: perplexity comparison, original vs quantized + LoRA ---
divider = "=" * 70
print("\n" + divider)
print("Step 2: Evaluate PPL")
print(divider)
# Middle return value is discarded here — presumably an intermediate
# metric; confirm against Runner.calculate_perplexity if needed.
base_ppl, _, tuned_ppl = runner.calculate_perplexity(
    original_model=True,
    quantized_model=True,
)
print(f" Original model PPL: {base_ppl:.4f}")
print(f" Quantized + LoRA SFT model PPL: {tuned_ppl:.4f}")
# --- Step 3: persist the LoRA-applied model (PyTorch .pt format) ---
divider = "=" * 70
print("\n" + divider)
print(f"Step 3: Saving LoRA-applied model to {SAVE_DIR}")
print(divider)
runner.save_quantized_model_pt(SAVE_DIR)
print(f"Model saved to: {SAVE_DIR}")
# Drop the runner's references first so the cached CUDA memory
# released below is actually reclaimable.
del runner
torch.cuda.empty_cache()
# --- Step 4: round-trip check — reload the artifacts written in Step 3 ---
divider = "=" * 70
print("\n" + divider)
print(f"Step 4: Loading model from {SAVE_DIR}")
print(divider)
loaded_model, loaded_tokenizer = load_quantized_model_pt(SAVE_DIR)
print(f"Loaded model type : {type(loaded_model).__name__}")
print(f"Loaded model device: {next(loaded_model.parameters()).device}")
# --- Step 5: smoke test — greedy generation on the reloaded model ---
divider = "=" * 70
print("\n" + divider)
print("Step 5: Generate text with loaded model")
print(divider)
# Query the model's own parameters for the device rather than assuming one.
target_device = next(loaded_model.parameters()).device
loaded_text = generate_text(
    loaded_model,
    loaded_tokenizer,
    PROMPT,
    device=target_device,
)
print(f"\nPrompt : {PROMPT}")
print(f"Generated: {loaded_text}")
print(divider)