-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathintent_classification_example.py
More file actions
106 lines (83 loc) · 3.46 KB
/
intent_classification_example.py
File metadata and controls
106 lines (83 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Simple Intent Classification Example
This example shows how to use ProTeGi to optimize prompts for customer support
intent classification. It takes a basic prompt and improves its accuracy
through automated optimization.
Setup:
1. Install requirements: pip install -r requirements.txt
2. Set API key: export ANTHROPIC_API_KEY="your_key"
3. Run: python intent_classification_example.py
"""
import os
import sys
from pathlib import Path
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from llm import ProviderFactory, LLMConfig
from evaluation import ClassificationDataset, DatasetItem, PromptEvaluator
from optimization import (
GradientGenerator,
PromptEditor,
BanditBeamSearch,
BanditBeamConfig
)
def create_simple_dataset():
    """Build a small labeled dataset covering four common support intents.

    Returns:
        ClassificationDataset: twelve utterances (three per intent) spanning
        refund, technical_support, billing, and general_inquiry.
    """
    # Example utterances grouped by their gold-label intent; insertion order
    # matches the original item ordering (refunds first, inquiries last).
    examples_by_intent = {
        "refund": [
            "I want my money back",
            "Can I return this product?",
            "This doesn't work, please refund",
        ],
        "technical_support": [
            "The app keeps crashing",
            "I can't log in to my account",
            "Getting error message when I save",
        ],
        "billing": [
            "My credit card was charged twice",
            "Need to update payment method",
            "What's this charge on my statement?",
        ],
        "general_inquiry": [
            "What are your shipping options?",
            "Do you have this in different colors?",
            "When will new products be available?",
        ],
    }
    items = [
        DatasetItem(text, intent)
        for intent, texts in examples_by_intent.items()
        for text in texts
    ]
    return ClassificationDataset(name="customer_intents", items=items)
def main():
    """Run the intent-classification demo: evaluate a baseline prompt,
    optimize it with ProTeGi-style bandit beam search, and report the gain.

    Requires the ANTHROPIC_API_KEY environment variable; exits early with a
    message when it is missing. All results are printed to stdout.
    """
    # Check for API key before doing any work that would hit the network.
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        print("Please set ANTHROPIC_API_KEY environment variable")
        return

    # Setup provider and dataset. Use a distinct name for the LLM config so it
    # isn't shadowed by the optimizer config below.
    llm_config = LLMConfig(model="claude-3-5-haiku-20241022")
    provider = ProviderFactory.create("anthropic", api_key=api_key, config=llm_config)
    dataset = create_simple_dataset()

    # Initial simple prompt — intentionally weak so optimization has headroom.
    initial_prompt = "What is the customer asking for?"

    # Evaluate the baseline so we can measure improvement later.
    evaluator = PromptEvaluator(provider)
    initial_result = evaluator.evaluate(initial_prompt, dataset)
    print(f"Initial prompt: '{initial_prompt}'")
    print(f"Initial F1 score: {initial_result.metrics.f1:.3f}")
    print(f"Initial accuracy: {initial_result.metrics.accuracy:.3f}")

    # Setup ProTeGi optimization (small beam/iterations keep the demo cheap).
    generator = GradientGenerator(provider)
    editor = PromptEditor(provider)
    search_config = BanditBeamConfig(
        beam_width=2,
        num_iterations=2,
        variants_per_candidate=2
    )
    protegi = BanditBeamSearch(evaluator, generator, editor, search_config)

    # Optimize prompt
    print("\nOptimizing prompt...")
    best = protegi.optimize(initial_prompt, dataset, metric="f1")
    print(f"\nOptimized prompt: '{best.prompt}'")
    print(f"Optimized F1 score: {best.mean_score:.3f}")

    # Guard against division by zero: a bad baseline prompt can score F1 = 0,
    # which previously crashed the script on the very last line.
    initial_f1 = initial_result.metrics.f1
    if initial_f1 > 0:
        improvement = ((best.mean_score - initial_f1) / initial_f1) * 100
        print(f"Improvement: {improvement:.1f}%")
    else:
        print("Improvement: n/a (initial F1 was 0)")


if __name__ == "__main__":
    main()