diff --git a/autochain/agent/conversational_agent/conversational_agent.py b/autochain/agent/conversational_agent/conversational_agent.py index cf93324..a1aa3b4 100644 --- a/autochain/agent/conversational_agent/conversational_agent.py +++ b/autochain/agent/conversational_agent/conversational_agent.py @@ -165,6 +165,7 @@ def plan( tool_strings = "\n\n".join( [f"> {tool.name}: \n{tool.description}" for tool in self.tools] ) + inputs = { "tool_names": tool_names, "tools": tool_strings, @@ -175,11 +176,14 @@ def plan( final_prompt = self.format_prompt( self.prompt_template, intermediate_steps, **inputs ) + logger.info(f"\nPlanning Input: {final_prompt[0].content} \n") full_output: Generation = self.llm.generate(final_prompt).generations[0] + agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse( - full_output.message + full_output.message, + self.llm ) print(f"Planning output: \n{repr(full_output.message.content)}", Fore.YELLOW) diff --git a/autochain/agent/conversational_agent/output_parser.py b/autochain/agent/conversational_agent/output_parser.py index b0518ec..d2a6169 100644 --- a/autochain/agent/conversational_agent/output_parser.py +++ b/autochain/agent/conversational_agent/output_parser.py @@ -5,13 +5,14 @@ from autochain.agent.message import BaseMessage from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser +from autochain.models.base import BaseLanguageModel from autochain.errors import OutputParserException from autochain.utils import print_with_color class ConvoJSONOutputParser(AgentOutputParser): - def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]: - response = self.load_json_output(message) + def parse(self, message: BaseMessage, llm: BaseLanguageModel) -> Union[AgentAction, AgentFinish]: + response = self.load_json_output(message, llm) action_name = response.get("tool", {}).get("name") action_args = response.get("tool", {}).get("args") diff --git a/autochain/agent/structs.py 
class AgentOutputParser(BaseModel):
    """Base parser turning a model message into an AgentAction or AgentFinish.

    Includes a self-healing JSON loader: when the model's output is not valid
    JSON, the provided ``llm`` is asked to repair it, up to ``max_retry``
    attempts, before giving up.
    """

    def load_json_output(
        self,
        message: BaseMessage,
        llm: BaseLanguageModel,
        max_retry: int = 3,
    ) -> Dict[str, Any]:
        """Parse the JSON payload embedded in ``message.content``.

        Args:
            message: model message whose content should contain a JSON object.
            llm: language model used to repair malformed JSON output.
            max_retry: maximum number of repair attempts before giving up.

        Returns:
            The parsed JSON object as a dictionary.

        Raises:
            ValueError: if the model cannot produce parseable JSON within
                ``max_retry`` attempts.
        """
        clean_text = self._extract_json_text(message.content)
        try:
            return json.loads(clean_text)
        except json.JSONDecodeError:
            print_with_color(
                "Generating JSON format attempt FAILED! Trying Again...",
                Fore.RED,
            )
            # Do not rebind the `message` parameter; build a fresh repair prompt.
            fix_request = self._fix_message(clean_text)
            return self._attempt_fix_and_generate(
                fix_request, llm, max_retry, attempt=0
            )

    @staticmethod
    def _fix_message(clean_text: str) -> UserMessage:
        """Build a prompt asking the model to repair malformed JSON.

        TODO(review): enrich this prompt iteratively (e.g. include the parse
        error text) so repeated attempts converge on parseable JSON.
        """
        return UserMessage(
            content=(
                "Fix the following json into correct format\n"
                "```json\n"
                f"{clean_text}\n"
                "```\n"
            )
        )

    @staticmethod
    def _extract_json_text(text: str) -> str:
        """Return the substring spanning the first '{' to the last '}'.

        If no braces are found, the original text is returned unchanged so the
        caller can still attempt a repair round-trip.
        """
        try:
            # str.index / str.rindex raise ValueError when the brace is absent.
            return text[text.index("{") : text.rindex("}") + 1].strip()
        except ValueError:
            return text

    def _attempt_fix_and_generate(
        self,
        message: BaseMessage,
        llm: BaseLanguageModel,
        max_retry: int,
        attempt: int = 0,
    ) -> Dict[str, Any]:
        """Repeatedly ask ``llm`` to repair the JSON until it parses.

        Iterative (bounded) retry instead of recursion: each failed attempt
        feeds the model's latest output back through ``_fix_message``.

        Args:
            message: the current repair prompt to send to the model.
            llm: language model used to generate the fixed JSON.
            max_retry: total number of generation attempts allowed.
            attempt: number of attempts already consumed (kept for
                backward compatibility with the recursive signature).

        Raises:
            ValueError: when ``max_retry`` attempts are exhausted.
        """
        while attempt < max_retry:
            full_output: Generation = llm.generate([message]).generations[0]
            try:
                return json.loads(full_output.message.content)
            except json.JSONDecodeError:
                print_with_color(
                    "Generating JSON format attempt FAILED! Trying Again...",
                    Fore.RED,
                )
                message = self._fix_message(
                    self._extract_json_text(full_output.message.content)
                )
                attempt += 1
        raise ValueError(
            "Max retry reached. Model is unable to generate proper JSON "
            "output. Try with another model!"
        )

    @abstractmethod
    def parse(
        self, message: BaseMessage, llm: BaseLanguageModel
    ) -> Union[AgentAction, AgentFinish]:
        """Parse a model message into an AgentAction or AgentFinish.

        Signature now includes ``llm`` to match the concrete override in
        ConvoJSONOutputParser and the call site in the agent's ``plan``.
        """