# Agent, dataset, and eval_fn (same as custom_agent_example.py)
# ---------------------------------------------------------------------------
2626
27+
class MyAgent:
    """Two-stage question-answering agent.

    A "planner" model first drafts a brief plan for the question, then a
    "solver" model answers the same question with that plan supplied as its
    system prompt.
    """

    def __init__(self, models):
        """Create the API client and record which model serves each stage.

        Args:
            models: Object indexed by ``"planner"`` and ``"solver"``; each
                value is later passed as ``model=`` to the chat-completions
                call for that stage.
        """
        self.client = OpenAI()
        self.planner_model = models["planner"]
        self.solver_model = models["solver"]

    def _complete(self, model, system_prompt, input_data):
        """Send one system + user exchange to ``model`` and return the reply content."""
        response = self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": input_data},
            ],
        )
        # Only the first choice is consulted, as in the original call chain.
        return response.choices[0].message.content

    def run(self, input_data):
        """Answer ``input_data`` with a plan-then-solve pair of model calls.

        Args:
            input_data: The question; sent unchanged as the user message of
                both requests.

        Returns:
            The message content of the solver model's first choice.
        """
        plan = self._complete(
            self.planner_model,
            "Create a brief plan to answer the question.",
            input_data,
        )
        # NOTE(review): the single spaces on either side of the interpolated
        # plan are reproduced verbatim from the source text; confirm they are
        # intended and not a formatting artefact.
        answer = self._complete(
            self.solver_model,
            f"Follow this plan and answer concisely:\n {plan} ",
            input_data,
        )
        return answer
5065
5166
@@ -71,19 +86,22 @@ def eval_fn(expected, actual):
# Selection algorithms
# ---------------------------------------------------------------------------
7388
89+
def run_auto():
    """method="auto" — automatically picks the best algorithm (default).

    Returns:
        Whatever ``ModelSelector.select_best(parallel=True)`` returns.
    """
    # Each keyword is passed exactly once; repeating one is a SyntaxError.
    selector = ModelSelector(
        agent=MyAgent,
        models=models,
        eval_fn=eval_fn,
        dataset=dataset,
        method="auto",
    )
    return selector.select_best(parallel=True)
8196
8297
8398def run_random ():
8499 """method="random" — evaluate a random subset of combinations."""
85100 selector = ModelSelector (
86- agent = MyAgent , models = models , eval_fn = eval_fn , dataset = dataset ,
101+ agent = MyAgent ,
102+ models = models ,
103+ eval_fn = eval_fn ,
104+ dataset = dataset ,
87105 method = "random" ,
88106 sample_fraction = 0.5 , # evaluate 50% of all combinations
89107 )
@@ -93,7 +111,10 @@ def run_random():
93111def run_hill_climbing ():
94112 """method="hill_climbing" — greedy search using model quality/speed rankings."""
95113 selector = ModelSelector (
96- agent = MyAgent , models = models , eval_fn = eval_fn , dataset = dataset ,
114+ agent = MyAgent ,
115+ models = models ,
116+ eval_fn = eval_fn ,
117+ dataset = dataset ,
97118 method = "hill_climbing" ,
98119 batch_size = 4 , # number of neighbors to evaluate per step
99120 )
@@ -103,7 +124,10 @@ def run_hill_climbing():
def run_arm_elimination():
    """method="arm_elimination" — eliminates statistically dominated combinations early.

    Returns:
        Whatever ``ModelSelector.select_best(parallel=True)`` returns.
    """
    # Each keyword is passed exactly once; repeating one is a SyntaxError.
    selector = ModelSelector(
        agent=MyAgent,
        models=models,
        eval_fn=eval_fn,
        dataset=dataset,
        method="arm_elimination",
    )
    return selector.select_best(parallel=True)
@@ -112,7 +136,10 @@ def run_arm_elimination():
112136def run_epsilon_lucb ():
113137 """method="epsilon_lucb" — stops when the best arm is identified within epsilon."""
114138 selector = ModelSelector (
115- agent = MyAgent , models = models , eval_fn = eval_fn , dataset = dataset ,
139+ agent = MyAgent ,
140+ models = models ,
141+ eval_fn = eval_fn ,
142+ dataset = dataset ,
116143 method = "epsilon_lucb" ,
117144 epsilon = 0.05 , # acceptable gap from the true best
118145 )
@@ -122,7 +149,10 @@ def run_epsilon_lucb():
122149def run_threshold ():
123150 """method="threshold" — classify combinations as above/below a quality threshold."""
124151 selector = ModelSelector (
125- agent = MyAgent , models = models , eval_fn = eval_fn , dataset = dataset ,
152+ agent = MyAgent ,
153+ models = models ,
154+ eval_fn = eval_fn ,
155+ dataset = dataset ,
126156 method = "threshold" ,
127157 threshold = 0.8 , # minimum acceptable accuracy
128158 )
@@ -132,7 +162,10 @@ def run_threshold():
def run_lm_proposal():
    """method="lm_proposal" — use a proposer LLM to shortlist promising combinations.

    Returns:
        Whatever ``ModelSelector.select_best(parallel=True)`` returns.
    """
    # Each keyword is passed exactly once; repeating one is a SyntaxError.
    selector = ModelSelector(
        agent=MyAgent,
        models=models,
        eval_fn=eval_fn,
        dataset=dataset,
        method="lm_proposal",
    )
    return selector.select_best(parallel=True)
@@ -141,7 +174,10 @@ def run_lm_proposal():
141174def run_bayesian ():
142175 """method="bayesian" — GP-based Bayesian optimization (requires agentopt[bayesian])."""
143176 selector = ModelSelector (
144- agent = MyAgent , models = models , eval_fn = eval_fn , dataset = dataset ,
177+ agent = MyAgent ,
178+ models = models ,
179+ eval_fn = eval_fn ,
180+ dataset = dataset ,
145181 method = "bayesian" ,
146182 batch_size = 4 ,
147183 )
0 commit comments