Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions helpers/hllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ def _call_api_sync(
return completion_obj


@hcacsimp.simple_cache(write_through=True, exclude_keys=["client", "cache_mode"])
@hcacsimp.simple_cache(
cache_type="pickle",
write_through=True,
exclude_keys=["client", "cache_mode", "cost_tracker"],
)
def _call_structured_api_sync(
# pylint: disable=unused-argument
# This is needed to support caching.
Expand All @@ -269,16 +273,18 @@ def _call_structured_api_sync(
response_format: type[T],
*,
images_as_base64: Optional[Tuple[str, ...]] = None,
cost_tracker: Optional[hllmcost.LLMCostTracker] = None,
print_cost: bool = False,
**create_kwargs,
) -> Any:
) -> T:
"""
Make a non-streaming structured API call.

See `get_structured_completion()` for parameter descriptions.

:param client: LLM client
:param response_format: expected structured output format
:return: OpenAI Response object with parsed output
:return: parsed output as the specified Pydantic model
"""
user_input = build_responses_input(
user_prompt, images_as_base64=images_as_base64
Expand All @@ -291,7 +297,16 @@ def _call_structured_api_sync(
text_format=response_format,
**create_kwargs,
)
return response
# Extract the parsed output.
parsed_output: T = response.output_parsed
# Track costs.
if cost_tracker is not None:
hdbg.dassert_isinstance(cost_tracker, hllmcost.LLMCostTracker)
cost = cost_tracker.calculate_cost(response)
cost_tracker.accumulate_cost(cost)
if print_cost:
_LOG.info("cost=%.6f", cost)
return parsed_output


# #############################################################################
Expand Down Expand Up @@ -572,7 +587,7 @@ def get_structured_completion(
f"Got provider_name='{llm_client.provider_name}'."
)
# Retrieve a structured response.
response = _call_structured_api_sync(
parsed_output: T = _call_structured_api_sync(
cache_mode=cache_mode,
client=llm_client.client,
model=llm_client.model,
Expand All @@ -581,16 +596,10 @@ def get_structured_completion(
temperature=temperature,
response_format=response_format,
images_as_base64=images_as_base64,
cost_tracker=cost_tracker,
print_cost=print_cost,
**create_kwargs,
)
parsed_output: T = response.output_parsed
# Track costs.
if cost_tracker is not None:
hdbg.dassert_isinstance(cost_tracker, hllmcost.LLMCostTracker)
cost = cost_tracker.calculate_cost(response)
cost_tracker.accumulate_cost(cost)
if print_cost:
_LOG.info("cost=%.6f", cost)
return parsed_output


Expand Down
Loading