Skip to content

Commit d1d795a

Browse files
WenyuehWenyuehCopilot
authored
wall time and cost (#27)
* wall time and cost * avoid overcounting * Update src/agentopt/model_selection/base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Wenyueh <norahua1996@outlook.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 1b61bcc commit d1d795a

1 file changed

Lines changed: 34 additions & 1 deletion

File tree

  • src/agentopt/model_selection

src/agentopt/model_selection/base.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ class SelectionResults(BaseModel):
8686
"""Results from model selection."""
8787

8888
results: List[ModelResult] = Field(default_factory=list)
89+
selection_wall_time_seconds: Optional[float] = None
90+
selection_cost: Optional[float] = Field(
91+
default=None,
92+
description="Total selection cost in USD.",
93+
)
8994

9095
def __iter__(self):
9196
return iter(self.results)
@@ -294,6 +299,16 @@ def row(
294299
)
295300
lines.append("")
296301

302+
# Selection overhead
303+
parts = []
304+
if self.selection_wall_time_seconds is not None:
305+
parts.append(f"{self.selection_wall_time_seconds:.2f}s")
306+
if self.selection_cost is not None:
307+
parts.append(f"${self.selection_cost:.6f}")
308+
if parts:
309+
lines.append(f"{pad} Selection overhead: {', '.join(parts)}")
310+
lines.append("")
311+
297312
return "\n".join(lines)
298313

299314
def print_summary(self) -> None:
@@ -667,11 +682,29 @@ def select_best(
667682
Returns:
668683
SelectionResults containing all model evaluation results.
669684
"""
685+
record_offset = len(self._tracker.get_records())
686+
t0 = time.time()
670687
try:
671-
return self._run_selection(parallel, max_concurrent)
688+
result = self._run_selection(parallel, max_concurrent)
672689
finally:
673690
self._tracker.stop()
674691

692+
result.selection_wall_time_seconds = time.time() - t0
693+
694+
# Cost: only non-cached calls made during this run
695+
input_tokens: Dict[str, int] = {}
696+
output_tokens: Dict[str, int] = {}
697+
for r in self._tracker.get_records()[record_offset:]:
698+
if r.cached:
699+
continue
700+
input_tokens[r.model] = input_tokens.get(r.model, 0) + r.prompt_tokens
701+
output_tokens[r.model] = output_tokens.get(r.model, 0) + r.completion_tokens
702+
result.selection_cost = compute_price(
703+
input_tokens, output_tokens, custom_prices=self._custom_prices,
704+
)
705+
706+
return result
707+
675708
@abstractmethod
676709
def _run_selection(
677710
self, parallel: bool = False, max_concurrent: int = 20,

0 commit comments

Comments
 (0)