13 changes: 11 additions & 2 deletions src/query_chat_model.py
@@ -11,6 +11,8 @@
 from config import debug_print
 from static_mutation import prune_equivalent_codes
 from assertion_rewriter import rewrite_assert
+import pydantic
+from openai.types.chat import ChatCompletion
 
 
 def gen_and_prune_codes(client, prog_data, tests_in_ctxt, token_counter=None):
@@ -315,7 +317,9 @@ def get_or_create_codex_response(client, prompt_val, best_of_val, temp_val, echo
 
     if config.codex_cache_file is not None and str(k) in config.codex_query_response_log:
         config.skip_codex_query_cnt = config.skip_codex_query_cnt + 1
-        resp = config.codex_query_response_log[str(k)][1]
+        # Need to parse into openai's ChatCompletion response type
+        # because it is expected at call sites.
+        resp = ChatCompletion.model_validate(config.codex_query_response_log[str(k)][1])
         debug_print(f"Cached response for {k} is {resp}")
         return resp
     assert best_of_val <= max_suggestions
@@ -367,6 +371,11 @@ def get_or_create_codex_response(client, prompt_val, best_of_val, temp_val, echo
             print("Current Tokens:, ", query_response.usage.total_tokens, "\tUsed tokens: ",
                   token_counter.used_tokens, "\tToken limit: ", token_counter.token_limit,
                   "\tSo far generated: ", len(response.choices))
-    v = (k, response, current_time)
+    # Depending on the version of openai, the response is a pydantic object or a dict.
+    # In the case of pydantic, we need to convert it to a dict so `json.dumps` works.
+    response_dict = (
+        response.model_dump() if isinstance(response, pydantic.BaseModel) else response
+    )
+    v = (k, response_dict, current_time)
     config.codex_query_response_log[str(k)] = v
     return response
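
Taken together, the two hunks give the cache a dict-based round-trip: responses are normalized to plain dicts before being written to the JSON-backed log, and re-validated into typed `ChatCompletion` objects when read back. A minimal sketch of that round-trip, assuming the openai v1 client and pydantic v2 (the `save_cached`/`load_cached` helper names are hypothetical, not part of the patch):

```python
import pydantic
from openai.types.chat import ChatCompletion


def save_cached(log: dict, key: str, response) -> None:
    # Newer openai clients return pydantic models; older ones return plain dicts.
    # Normalize to a dict so the whole log stays serializable with json.dumps.
    response_dict = (
        response.model_dump() if isinstance(response, pydantic.BaseModel) else response
    )
    log[key] = response_dict


def load_cached(log: dict, key: str) -> ChatCompletion:
    # Call sites expect a typed ChatCompletion, so re-validate the stored dict
    # back into the response object before returning it.
    return ChatCompletion.model_validate(log[key])
```

A side effect worth noting: `model_validate` raises if a cached entry no longer matches the current response schema, so stale cache entries surface as errors at load time rather than as silent attribute failures at the call sites.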