class Sentient:
    """Facade that lazily builds an ``Orchestrator`` and runs goals through it.

    The orchestrator is created on the first ``invoke`` call and reused by
    every later call on the same instance.
    """

    def __init__(self):
        # Populated by _initialize() on first use.
        self.orchestrator = None

    def _create_state_to_agent_map(
        self,
        provider: str,
        model: str,
        custom_base_url: str = None,
        max_retries: int = 3,
    ):
        """Build the state -> agent mapping handed to the orchestrator.

        Args:
            provider: Provider name passed to ``get_provider``.
            model: Model name given to the agent.
            custom_base_url: Optional base URL passed to ``get_provider``.
            max_retries: Retry budget given to the agent.

        Returns:
            A dict mapping ``State.BASE_AGENT`` to a configured ``Agent``.
        """
        provider_instance = get_provider(provider, custom_base_url)
        return {
            State.BASE_AGENT: Agent(
                provider=provider_instance,
                model_name=model,
                max_retries=max_retries,
            ),
        }

    async def _initialize(
        self,
        provider: str,
        model: str,
        custom_base_url: str = None,
        max_retries: int = 3,
    ):
        """Create and start the orchestrator if this instance has none yet.

        Later calls are no-ops: the provider/model/retry settings of the
        first call stay in effect for the lifetime of the instance.
        """
        if self.orchestrator:
            return

        state_to_agent_map = self._create_state_to_agent_map(
            provider, model, custom_base_url, max_retries
        )
        # Hand the same retry budget to the orchestrator so both layers agree
        # instead of the orchestrator silently falling back to its default.
        self.orchestrator = Orchestrator(
            state_to_agent_map=state_to_agent_map,
            model=model,
            max_retries=max_retries,
        )
        # NOTE(review): start() is assumed to belong to the first-use branch
        # (started once per orchestrator) - confirm against the original file.
        await self.orchestrator.start()

    async def invoke(
        self,
        goal: str,
        provider: str = "openai",
        model: str = "gpt-4o-2024-08-06",
        task_instructions: str = None,
        custom_base_url: str = None,
        max_retries: int = 3,
    ):
        """Run ``goal`` through the orchestrator and return its result.

        ``goal`` stays the first parameter and ``max_retries`` is an optional
        trailing parameter, so callers written before the retry option was
        introduced (``invoke("...")`` / ``invoke(goal="...")``) keep working.

        Args:
            goal: Command forwarded to ``Orchestrator.execute_command``.
            provider: Provider name used when the orchestrator is first built.
            model: Model name used when the orchestrator is first built.
            task_instructions: Optional instructions stored via
                ``ltm.set_task_instructions`` before execution.
            custom_base_url: Optional base URL for the provider.
            max_retries: Retry budget for the agent and orchestrator; must be
                at least 1.

        Returns:
            Whatever ``Orchestrator.execute_command`` returns for ``goal``.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1.
        """
        # Fail fast: a non-positive budget would make every retry loop a no-op.
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        if task_instructions:
            ltm.set_task_instructions(task_instructions)

        await self._initialize(provider, model, custom_base_url, max_retries)
        return await self.orchestrator.execute_command(goal)
b/sentient/__main__.py @@ -8,7 +8,7 @@ async def main(): # Define state machine state_to_agent_map = { - State.BASE_AGENT: Agent(), + State.BASE_AGENT: Agent(max_retries=3), } orchestrator = Orchestrator(state_to_agent_map=state_to_agent_map) diff --git a/sentient/core/agent/agent.py b/sentient/core/agent/agent.py index 926fcbe..c7dad74 100644 --- a/sentient/core/agent/agent.py +++ b/sentient/core/agent/agent.py @@ -9,7 +9,7 @@ class Agent(BaseAgent): - def __init__(self, provider:LLMProvider, model_name: str): + def __init__(self, provider: LLMProvider, model_name: str, max_retries: int): self.name = "sentient" self.ltm = None self.ltm = self.__get_ltm() @@ -22,6 +22,7 @@ def __init__(self, provider:LLMProvider, model_name: str): keep_message_history=False, provider=provider, model_name=model_name, + max_retries=max_retries, ) @staticmethod diff --git a/sentient/core/agent/base.py b/sentient/core/agent/base.py index 886c9f2..01a202f 100644 --- a/sentient/core/agent/base.py +++ b/sentient/core/agent/base.py @@ -16,6 +16,7 @@ from sentient.utils.providers import get_provider, LLMProvider class BaseAgent: + def __init__( self, name: str, @@ -26,9 +27,11 @@ def __init__( keep_message_history: bool = True, provider: LLMProvider = None, model_name: str = None, + max_retries: int = 3 ): - # Metdata + # Metadata self.agent_name = name + self.max_retries = max_retries # Messages self.system_prompt = system_prompt @@ -48,7 +51,7 @@ def __init__( # if self.provider_name == "google": # self.client = instructor.from_gemini( # client=genai.GenerativeModel( - # model_name=model_name, + # model_name=model_name, # ) # ) if self.provider_name == "groq": @@ -59,7 +62,7 @@ def __init__( else: self.client = openai.Client(**client_config) self.client = instructor.from_openai(self.client, mode=Mode.JSON) - + # Set model name self.model_name = model_name @@ -119,7 +122,7 @@ async def run( "content": f"Understood. I will properly follow the instructions given. 
Can you provide me with the current page DOM and URL please?", } ) - + # input dom and current page url in a separate message so that the LLM can pay attention to completed tasks better. *based on personal vibe check* if hasattr(input_data, "current_page_dom") and hasattr( input_data, "current_page_url" @@ -131,22 +134,17 @@ async def run( } ) - while True: - # TODO: - # 1. better exeception handling and messages while calling the client - # 2. remove the else block as JSON mode in instrutor won't allow us to pass in tools. - # 3. add a max_turn here to prevent a inifinite fallout + for attempt in range(self.max_retries): try: if len(self.tools_list) == 0: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model_name, messages=self.messages, response_model=self.output_format, max_retries=3, - max_tokens=1000 if self.provider_name == "anthropic" else None, ) else: - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model_name, messages=self.messages, response_model=self.output_format, @@ -168,17 +166,23 @@ async def run( # continue # parsed_response_content: self.output_format = response_message.parsed - + assert isinstance(response, self.output_format) return response except AssertionError: + logger.error(f"Attempt {attempt + 1} failed: Response type mismatch") + if attempt == self.max_retries - 1: raise TypeError( - f"Expected response_message to be of type {self.output_format.__name__}, but got {type(response).__name__}") + f"Expected response_message to be of type {self.output_format.__name__}, but got {type(response).__name__}" + ) except Exception as e: - logger.error(f"Unexpected error: {str(e)}") - raise + logger.error(f"Attempt {attempt + 1} failed: {str(e)}") + if attempt == self.max_retries - 1: + raise - + raise RuntimeError( + f"Failed to get a valid response after {self.max_retries} attempts" + ) async def 
_append_tool_response(self, tool_call): function_name = tool_call.function.name diff --git a/sentient/core/orchestrator/orchestrator.py b/sentient/core/orchestrator/orchestrator.py index fecb710..ff5a00d 100644 --- a/sentient/core/orchestrator/orchestrator.py +++ b/sentient/core/orchestrator/orchestrator.py @@ -1,10 +1,11 @@ import asyncio import textwrap import uuid -from typing import Dict, List +from typing import Dict, List, Union from colorama import Fore, init from dotenv import load_dotenv +from pydantic import ValidationError from langsmith import traceable from sentient.core.agent.base import BaseAgent @@ -24,20 +25,23 @@ from sentient.core.skills.open_url import openurl from sentient.core.skills.enter_text_and_click import enter_text_and_click from sentient.core.web_driver.playwright import PlaywrightManager +from sentient.utils.logger import logger init(autoreset=True) class Orchestrator: def __init__( - self, state_to_agent_map: Dict[State, BaseAgent], eval_mode: bool = False + self, state_to_agent_map: Dict[State, BaseAgent], eval_mode: bool = False, model: str = None, max_retries: int = 3 ): load_dotenv() self.state_to_agent_map = state_to_agent_map self.playwright_manager = PlaywrightManager() self.eval_mode = eval_mode self.shutdown_event = asyncio.Event() - # self.session_id = str(uuid.uuid4()) + self.session_id = str(uuid.uuid4()) + self.model = model + self.max_retries = max_retries async def start(self): print("Starting orchestrator") @@ -46,7 +50,7 @@ async def start(self): # if not self.eval_mode: # await self._command_loop() - + @classmethod async def invoke(cls, command: str): orchestrator = cls() @@ -87,11 +91,22 @@ async def execute_command(self, command: str): ) print(f"Executing command {self.memory.objective}") while self.memory.current_state != State.COMPLETED: - await self._handle_state() + try: + await self._handle_state() + except ValidationError as ve: + logger.error(f"Validation error occurred: {ve}") + if attempt == 
self.max_retries - 1: + raise + continue + except asyncio.TimeoutError: + logger.error("Timeout occurred during state handling") + # Handle timeout, possibly by retrying or moving to the next state + continue self._print_final_response() return self.memory.final_response except Exception as e: - print(f"Error executing the command {self.memory.objective}: {e}") + logger.error(f"Error executing the command {self.memory.objective}: {e}") + raise def run(self) -> Memory: while self.memory.current_state != State.COMPLETED: @@ -105,14 +120,13 @@ async def _handle_state(self): if current_state not in self.state_to_agent_map: raise ValueError(f"Unhandled state! No agent for {current_state}") - + if current_state == State.BASE_AGENT: - await self._handle_agnet() + await self._handle_agent() else: raise ValueError(f"Unhandled state: {current_state}") - - async def _handle_agnet(self): + async def _handle_agent(self): agent = self.state_to_agent_map[State.BASE_AGENT] self._print_memory_and_agent(agent.name) @@ -127,15 +141,25 @@ async def _handle_agnet(self): current_page_dom=str(dom), ) - output: AgentOutput = await agent.run( - input_data - ) - - await self._update_memory_from_agent(output) + max_retries = 3 + for attempt in range(max_retries): + try: + output: AgentOutput = await agent.run( + input_data, session_id=self.session_id, model=self.model + ) + await self._update_memory_from_agent(output) + break + except ValidationError as ve: + logger.error(f"Validation error on attempt {attempt + 1}: {ve}") + if attempt == max_retries - 1: + raise + except asyncio.TimeoutError: + logger.error(f"Timeout on attempt {attempt + 1}") + if attempt == max_retries - 1: + raise print(f"{Fore.MAGENTA}Base Agent Q has updated the memory.") - async def _update_memory_from_agent(self, agentq_output: AgentOutput): if agentq_output.is_complete: self.memory.current_state = State.COMPLETED @@ -166,32 +190,36 @@ async def _update_memory_from_agent(self, agentq_output: AgentOutput): async def 
handle_agent_actions(self, actions: List[Action]): results = [] for action in actions: - if action.type == ActionType.GOTO_URL: - result = await openurl(url=action.website, timeout=action.timeout or 1) - print("Action - GOTO") - elif action.type == ActionType.TYPE: - entry = EnterTextEntry( - query_selector=f"[mmid='{action.mmid}']", text=action.content - ) - result = await entertext(entry) - print("Action - TYPE") - elif action.type == ActionType.CLICK: - result = await click( - selector=f"[mmid='{action.mmid}']", - wait_before_execution=action.wait_before_execution or 1, - ) - print("Action - CLICK") - elif action.type == ActionType.ENTER_TEXT_AND_CLICK: - result = await enter_text_and_click( - text_selector=f"[mmid='{action.text_element_mmid}']", - text_to_enter=action.text_to_enter, - click_selector=f"[mmid='{action.click_element_mmid}']", - wait_before_click_execution=action.wait_before_click_execution - or 1.5, - ) - print("Action - ENTER TEXT AND CLICK") - else: - result = f"Unsupported action type: {action.type}" + try: + if action.type == ActionType.GOTO_URL: + result = await openurl(url=action.website, timeout=action.timeout or 1) + print("Action - GOTO") + elif action.type == ActionType.TYPE: + entry = EnterTextEntry( + query_selector=f"[mmid='{action.mmid}']", text=action.content + ) + result = await entertext(entry) + print("Action - TYPE") + elif action.type == ActionType.CLICK: + result = await click( + selector=f"[mmid='{action.mmid}']", + wait_before_execution=action.wait_before_execution or 1, + ) + print("Action - CLICK") + elif action.type == ActionType.ENTER_TEXT_AND_CLICK: + result = await enter_text_and_click( + text_selector=f"[mmid='{action.text_element_mmid}']", + text_to_enter=action.text_to_enter, + click_selector=f"[mmid='{action.click_element_mmid}']", + wait_before_click_execution=action.wait_before_click_execution + or 1.5, + ) + print("Action - ENTER TEXT AND CLICK") + else: + result = f"Unsupported action type: {action.type}" + 
except Exception as e: + logger.error(f"Error executing action {action.type}: {e}") + result = f"Error executing action {action.type}: {str(e)}" results.append(result)