src/chains/qvkg_chain.py (70 changes: 68 additions & 2 deletions)
@@ -26,7 +26,8 @@ def __init__(self, config, default_k=10):
         self.config = config
         self.kg_chain = KGChain(config=config)
         self.qv_chain = QuestionLookupChain(config=config)
-        self.llm = LLMFactory.get_llm(config=config)
+        self.raw_llm = LLMFactory.get_raw_llm(config=config)
+        self.llm = self.raw_llm | LLMFactory.strip_thought
         self.default_lookup_k = default_k
         self.langfuse_client = Langfuse(secret_key=config.LANGFUSE_SECRET_KEY,
                                         public_key=config.LANGFUSE_PUBLIC_KEY,
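
The new `raw_llm | LLMFactory.strip_thought` composition keeps a handle on the underlying model (used later for token counting) while the chain-facing `self.llm` emits answers with any reasoning/"thought" text removed. `LLMFactory` is not part of this diff, so the following is only a sketch of what such a post-processing step could look like, assuming the thought text is wrapped in `<think>` tags:

```python
# Hypothetical sketch, not the repo's LLMFactory: strip a leading
# <think>...</think> block from a chat model's response.
import re


def strip_thought(message) -> str:
    """Return the response text with any <think>...</think> block removed."""
    text = message.content if hasattr(message, "content") else str(message)
    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)

# In LCEL, piping a plain callable after a Runnable coerces it into a
# RunnableLambda, so a composition like the one in this diff could be:
# llm = raw_llm | strip_thought
```
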
@@ -146,7 +147,7 @@ def as_generative_chain(self, lookup_parameters=None):
                 "context": x["context"],
                 "chat_history": x["chat_history"],
                 "user_persona": x["user_persona"]
-            }) | self.COMBINED_ANSWER_PROMPT | self.llm.with_config(name="answer_generation") | StrOutputParser(),
+            }) | RunnableLambda(self._limit_context_tokens) | self.COMBINED_ANSWER_PROMPT | self.llm.with_config(name="answer_generation") | StrOutputParser(),
             "extra": RunnableLambda(lambda x: {"knowledge_graph": x["kg_extra"].get("knowledge_graph", {})}),
             "prompt": RunnableLambda(lambda x: {
                 "input": x["input"],
@@ -183,10 +184,75 @@ def _create_answer_generation_prompt(self, prompt_name):
             MessagesPlaceholder(variable_name="chat_history"),
             ("user", "{input}")
         ])
+
+    def _get_raw_from_langfuse(self, prompt_name: str) -> str:
+        """Gets the raw string of a prompt stored in Langfuse."""
+        return self.langfuse_client.get_prompt(prompt_name).prompt

+    def _limit_context_tokens(self, inputs: Dict) -> Dict:
+        """
+        Checks the token count of the prompt constructed from inputs.
+        If it exceeds the limit, iteratively removes variables from studies,
+        and then whole studies, until the prompt fits.
+        """
+        MAX_TOKENS = 16000  # Leave some buffer for the response
+
+        # We need to format the messages to count tokens accurately
+        def get_token_count(current_inputs):
+            messages = self.COMBINED_ANSWER_PROMPT.format_messages(**current_inputs)
+            try:
+                return self.raw_llm.get_num_tokens_from_messages(messages)
+            except Exception:
+                # Fallback if the LLM doesn't support token counting.
+                # Rough estimate: 4 chars per token
+                return sum(len(m.content) for m in messages) / 4
+
+        if get_token_count(inputs) <= MAX_TOKENS:
+            return inputs
+
+        # Parse context
+        try:
+            context_str = inputs["context"]
+            # combine_xml_outputs returns <studies>...</studies>
+            root = ET.fromstring(context_str)
+
+            # Loop until we are under the limit
+            while get_token_count(inputs) > MAX_TOKENS:
+                studies = root.findall('study')
+                if not studies:
+                    break
+
+                # Find the study with the most variables
+                study_with_most_vars = None
+                max_vars = 0
+
+                for study in studies:
+                    variables_container = study.find('variables')
+                    if variables_container is not None:
+                        vars_in_study = variables_container.findall('variable')
+                        if len(vars_in_study) > max_vars:
+                            max_vars = len(vars_in_study)
+                            study_with_most_vars = study
+
+                # If we found a study with variables, remove one of them
+                if study_with_most_vars is not None and max_vars > 0:
+                    variables_container = study_with_most_vars.find('variables')
+                    vars_in_study = variables_container.findall('variable')
+                    # Remove the last variable
+                    variables_container.remove(vars_in_study[-1])
+                else:
+                    # No variables left in any study, remove the last study
+                    root.remove(studies[-1])
+
+                # Reconstruct context so the next token check sees the reduction
+                new_context = ET.tostring(root, encoding='unicode')
+                inputs["context"] = new_context
+
+        except ET.ParseError:
+            # If we can't parse, we can't intelligently reduce.
+            pass
+
+        return inputs


if __name__ == "__main__":
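
For reference, the trimming strategy in `_limit_context_tokens` can be exercised on its own. The sketch below swaps the real token counter for a simple character budget and uses made-up study data, so it runs without an LLM or the rest of the chain:

```python
# Standalone sketch of the variable-then-study trimming loop; the budget and
# the sample XML are illustrative, not values from this repository.
import xml.etree.ElementTree as ET


def trim_studies(context_str: str, max_chars: int = 300) -> str:
    root = ET.fromstring(context_str)
    while len(ET.tostring(root, encoding="unicode")) > max_chars:
        studies = root.findall("study")
        if not studies:
            break
        # Drop a variable from the study that has the most; otherwise drop a study.
        fattest = max(studies, key=lambda s: len(s.findall("variables/variable")))
        variables = fattest.find("variables")
        if variables is not None and variables.findall("variable"):
            variables.remove(variables.findall("variable")[-1])
        else:
            root.remove(studies[-1])
    return ET.tostring(root, encoding="unicode")


sample = "<studies>" + "".join(
    f"<study><title>Study {i}</title><variables>"
    + "".join(f"<variable>var_{i}_{j}</variable>" for j in range(5))
    + "</variables></study>"
    for i in range(3)
) + "</studies>"
print(trim_studies(sample))
```
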