diff --git a/helpers/hllm.py b/helpers/hllm.py index b306d83e7..5375121ba 100644 --- a/helpers/hllm.py +++ b/helpers/hllm.py @@ -256,7 +256,11 @@ def _call_api_sync( return completion_obj -@hcacsimp.simple_cache(write_through=True, exclude_keys=["client", "cache_mode"]) +@hcacsimp.simple_cache( + cache_type="pickle", + write_through=True, + exclude_keys=["client", "cache_mode", "cost_tracker"], +) def _call_structured_api_sync( # pylint: disable=unused-argument # This is needed to support caching. @@ -269,8 +273,10 @@ def _call_structured_api_sync( response_format: type[T], *, images_as_base64: Optional[Tuple[str, ...]] = None, + cost_tracker: Optional[hllmcost.LLMCostTracker] = None, + print_cost: bool = False, **create_kwargs, -) -> Any: +) -> T: """ Make a non-streaming structured API call. @@ -278,7 +284,7 @@ def _call_structured_api_sync( :param client: LLM client :param response_format: expected structured output format - :return: OpenAI Response object with parsed output + :return: parsed output as the specified Pydantic model """ user_input = build_responses_input( user_prompt, images_as_base64=images_as_base64 @@ -291,7 +297,16 @@ def _call_structured_api_sync( text_format=response_format, **create_kwargs, ) - return response + # Extract the parsed output. + parsed_output: T = response.output_parsed + # Track costs. + if cost_tracker is not None: + hdbg.dassert_isinstance(cost_tracker, hllmcost.LLMCostTracker) + cost = cost_tracker.calculate_cost(response) + cost_tracker.accumulate_cost(cost) + if print_cost: + _LOG.info("cost=%.6f", cost) + return parsed_output # ############################################################################# @@ -572,7 +587,7 @@ def get_structured_completion( f"Got provider_name='{llm_client.provider_name}'." ) # Retrieve a structured response. - response = _call_structured_api_sync( + parsed_output: T = _call_structured_api_sync( cache_mode=cache_mode, client=llm_client.client, model=llm_client.model, @@ -581,16 +596,10 @@ def get_structured_completion( temperature=temperature, response_format=response_format, images_as_base64=images_as_base64, + cost_tracker=cost_tracker, + print_cost=print_cost, **create_kwargs, ) - parsed_output: T = response.output_parsed - # Track costs. - if cost_tracker is not None: - hdbg.dassert_isinstance(cost_tracker, hllmcost.LLMCostTracker) - cost = cost_tracker.calculate_cost(response) - cost_tracker.accumulate_cost(cost) - if print_cost: - _LOG.info("cost=%.6f", cost) return parsed_output