Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions sentient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,30 @@
class Sentient:
def __init__(self):
self.orchestrator = None
def _create_state_to_agent_map(self, provider: str, model: str, custom_base_url: str = None):

def _create_state_to_agent_map(self, provider: str, model: str, custom_base_url: str, max_retries: int):
provider_instance = get_provider(provider, custom_base_url)
return {
State.BASE_AGENT: Agent(provider=provider_instance, model_name=model),
State.BASE_AGENT: Agent(provider=provider_instance, model_name=model, max_retries=max_retries),
}

async def _initialize(self, provider: str, model: str, custom_base_url: str = None):
async def _initialize(self, provider: str, model: str, custom_base_url: str, max_retries: int):
if not self.orchestrator:
state_to_agent_map = self._create_state_to_agent_map(provider, model, custom_base_url)
self.orchestrator = Orchestrator(state_to_agent_map=state_to_agent_map)
state_to_agent_map = self._create_state_to_agent_map(provider, model, custom_base_url, max_retries)
self.orchestrator = Orchestrator(state_to_agent_map=state_to_agent_map, model=model)
await self.orchestrator.start()

async def invoke(
    self,
    goal: str,
    provider: str = "openai",
    model: str = "gpt-4o-2024-08-06",
    task_instructions: str = None,
    custom_base_url: str = None,
    max_retries: int = 3,
):
    """Run a single goal through the orchestrator and return its result.

    Args:
        goal: The command passed to ``orchestrator.execute_command``.
        provider: Provider identifier used when the orchestrator is first
            initialized.
        model: Model name used when the orchestrator is first initialized.
        task_instructions: When truthy, forwarded to
            ``ltm.set_task_instructions`` before initialization.
        custom_base_url: Optional base URL for the provider.
        max_retries: Retry budget forwarded to ``_initialize``. Placed last
            with a default so callers that pass ``goal`` as the first
            positional argument keep working.

    Returns:
        Whatever ``self.orchestrator.execute_command(goal)`` returns.
    """
    if task_instructions:
        ltm.set_task_instructions(task_instructions)
    await self._initialize(provider, model, custom_base_url, max_retries)
    return await self.orchestrator.execute_command(goal)

Expand Down
2 changes: 1 addition & 1 deletion sentient/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
async def main():
# Define state machine
state_to_agent_map = {
State.BASE_AGENT: Agent(),
State.BASE_AGENT: Agent(max_retries=3),
}

orchestrator = Orchestrator(state_to_agent_map=state_to_agent_map)
Expand Down
3 changes: 2 additions & 1 deletion sentient/core/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class Agent(BaseAgent):
def __init__(self, provider:LLMProvider, model_name: str):
def __init__(self, provider: LLMProvider, model_name: str, max_retries: int):
self.name = "sentient"
self.ltm = None
self.ltm = self.__get_ltm()
Expand All @@ -22,6 +22,7 @@ def __init__(self, provider:LLMProvider, model_name: str):
keep_message_history=False,
provider=provider,
model_name=model_name,
max_retries=max_retries,
)

@staticmethod
Expand Down
38 changes: 21 additions & 17 deletions sentient/core/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sentient.utils.providers import get_provider, LLMProvider

class BaseAgent:

def __init__(
self,
name: str,
Expand All @@ -26,9 +27,11 @@ def __init__(
keep_message_history: bool = True,
provider: LLMProvider = None,
model_name: str = None,
max_retries: int = 3
):
# Metdata
# Metadata
self.agent_name = name
self.max_retries = max_retries

# Messages
self.system_prompt = system_prompt
Expand All @@ -48,7 +51,7 @@ def __init__(
# if self.provider_name == "google":
# self.client = instructor.from_gemini(
# client=genai.GenerativeModel(
# model_name=model_name,
# model_name=model_name,
# )
# )
if self.provider_name == "groq":
Expand All @@ -59,7 +62,7 @@ def __init__(
else:
self.client = openai.Client(**client_config)
self.client = instructor.from_openai(self.client, mode=Mode.JSON)

# Set model name
self.model_name = model_name

Expand Down Expand Up @@ -119,7 +122,7 @@ async def run(
"content": f"Understood. I will properly follow the instructions given. Can you provide me with the current page DOM and URL please?",
}
)

# input dom and current page url in a separate message so that the LLM can pay attention to completed tasks better. *based on personal vibe check*
if hasattr(input_data, "current_page_dom") and hasattr(
input_data, "current_page_url"
Expand All @@ -131,22 +134,17 @@ async def run(
}
)

while True:
# TODO:
# 1. better exception handling and messages while calling the client
# 2. remove the else block, as JSON mode in instructor won't allow us to pass in tools.
# 3. add a max_turn here to prevent an infinite fallout
for attempt in range(self.max_retries):
try:
if len(self.tools_list) == 0:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model=self.model_name,
messages=self.messages,
response_model=self.output_format,
max_retries=3,
max_tokens=1000 if self.provider_name == "anthropic" else None,
)
else:
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model=self.model_name,
messages=self.messages,
response_model=self.output_format,
Expand All @@ -168,17 +166,23 @@ async def run(
# continue

# parsed_response_content: self.output_format = response_message.parsed

assert isinstance(response, self.output_format)
return response
except AssertionError:
logger.error(f"Attempt {attempt + 1} failed: Response type mismatch")
if attempt == self.max_retries - 1:
raise TypeError(
f"Expected response_message to be of type {self.output_format.__name__}, but got {type(response).__name__}")
f"Expected response_message to be of type {self.output_format.__name__}, but got {type(response).__name__}"
)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise
logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
if attempt == self.max_retries - 1:
raise


raise RuntimeError(
f"Failed to get a valid response after {self.max_retries} attempts"
)

async def _append_tool_response(self, tool_call):
function_name = tool_call.function.name
Expand Down
112 changes: 70 additions & 42 deletions sentient/core/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import textwrap
import uuid
from typing import Dict, List
from typing import Dict, List, Union

from colorama import Fore, init
from dotenv import load_dotenv
from pydantic import ValidationError
from langsmith import traceable

from sentient.core.agent.base import BaseAgent
Expand All @@ -24,20 +25,23 @@
from sentient.core.skills.open_url import openurl
from sentient.core.skills.enter_text_and_click import enter_text_and_click
from sentient.core.web_driver.playwright import PlaywrightManager
from sentient.utils.logger import logger

init(autoreset=True)


class Orchestrator:
def __init__(
    self,
    state_to_agent_map: Dict[State, BaseAgent],
    eval_mode: bool = False,
    model: str = None,
    max_retries: int = 3,
):
    """Store orchestrator configuration and create per-instance runtime objects.

    Args:
        state_to_agent_map: Mapping from each state to the agent handling it.
        eval_mode: Stored as ``self.eval_mode``.
        model: Stored as ``self.model``.
        max_retries: Stored as ``self.max_retries``.
    """
    load_dotenv()

    # Caller-supplied configuration.
    self.state_to_agent_map = state_to_agent_map
    self.eval_mode = eval_mode
    self.model = model
    self.max_retries = max_retries

    # Runtime collaborators and bookkeeping created fresh for each instance.
    self.playwright_manager = PlaywrightManager()
    self.shutdown_event = asyncio.Event()
    self.session_id = str(uuid.uuid4())

async def start(self):
print("Starting orchestrator")
Expand All @@ -46,7 +50,7 @@ async def start(self):

# if not self.eval_mode:
# await self._command_loop()

@classmethod
async def invoke(cls, command: str):
orchestrator = cls()
Expand Down Expand Up @@ -87,11 +91,22 @@ async def execute_command(self, command: str):
)
print(f"Executing command {self.memory.objective}")
while self.memory.current_state != State.COMPLETED:
await self._handle_state()
try:
await self._handle_state()
except ValidationError as ve:
logger.error(f"Validation error occurred: {ve}")
if attempt == self.max_retries - 1:
raise
continue
except asyncio.TimeoutError:
logger.error("Timeout occurred during state handling")
# Handle timeout, possibly by retrying or moving to the next state
continue
self._print_final_response()
return self.memory.final_response
except Exception as e:
print(f"Error executing the command {self.memory.objective}: {e}")
logger.error(f"Error executing the command {self.memory.objective}: {e}")
raise

def run(self) -> Memory:
while self.memory.current_state != State.COMPLETED:
Expand All @@ -105,14 +120,13 @@ async def _handle_state(self):

if current_state not in self.state_to_agent_map:
raise ValueError(f"Unhandled state! No agent for {current_state}")

if current_state == State.BASE_AGENT:
await self._handle_agnet()
await self._handle_agent()
else:
raise ValueError(f"Unhandled state: {current_state}")


async def _handle_agnet(self):
async def _handle_agent(self):
agent = self.state_to_agent_map[State.BASE_AGENT]
self._print_memory_and_agent(agent.name)

Expand All @@ -127,15 +141,25 @@ async def _handle_agnet(self):
current_page_dom=str(dom),
)

output: AgentOutput = await agent.run(
input_data
)

await self._update_memory_from_agent(output)
max_retries = 3
for attempt in range(max_retries):
try:
output: AgentOutput = await agent.run(
input_data, session_id=self.session_id, model=self.model
)
await self._update_memory_from_agent(output)
break
except ValidationError as ve:
logger.error(f"Validation error on attempt {attempt + 1}: {ve}")
if attempt == max_retries - 1:
raise
except asyncio.TimeoutError:
logger.error(f"Timeout on attempt {attempt + 1}")
if attempt == max_retries - 1:
raise

print(f"{Fore.MAGENTA}Base Agent Q has updated the memory.")


async def _update_memory_from_agent(self, agentq_output: AgentOutput):
if agentq_output.is_complete:
self.memory.current_state = State.COMPLETED
Expand Down Expand Up @@ -166,32 +190,36 @@ async def _update_memory_from_agent(self, agentq_output: AgentOutput):
async def handle_agent_actions(self, actions: List[Action]):
results = []
for action in actions:
if action.type == ActionType.GOTO_URL:
result = await openurl(url=action.website, timeout=action.timeout or 1)
print("Action - GOTO")
elif action.type == ActionType.TYPE:
entry = EnterTextEntry(
query_selector=f"[mmid='{action.mmid}']", text=action.content
)
result = await entertext(entry)
print("Action - TYPE")
elif action.type == ActionType.CLICK:
result = await click(
selector=f"[mmid='{action.mmid}']",
wait_before_execution=action.wait_before_execution or 1,
)
print("Action - CLICK")
elif action.type == ActionType.ENTER_TEXT_AND_CLICK:
result = await enter_text_and_click(
text_selector=f"[mmid='{action.text_element_mmid}']",
text_to_enter=action.text_to_enter,
click_selector=f"[mmid='{action.click_element_mmid}']",
wait_before_click_execution=action.wait_before_click_execution
or 1.5,
)
print("Action - ENTER TEXT AND CLICK")
else:
result = f"Unsupported action type: {action.type}"
try:
if action.type == ActionType.GOTO_URL:
result = await openurl(url=action.website, timeout=action.timeout or 1)
print("Action - GOTO")
elif action.type == ActionType.TYPE:
entry = EnterTextEntry(
query_selector=f"[mmid='{action.mmid}']", text=action.content
)
result = await entertext(entry)
print("Action - TYPE")
elif action.type == ActionType.CLICK:
result = await click(
selector=f"[mmid='{action.mmid}']",
wait_before_execution=action.wait_before_execution or 1,
)
print("Action - CLICK")
elif action.type == ActionType.ENTER_TEXT_AND_CLICK:
result = await enter_text_and_click(
text_selector=f"[mmid='{action.text_element_mmid}']",
text_to_enter=action.text_to_enter,
click_selector=f"[mmid='{action.click_element_mmid}']",
wait_before_click_execution=action.wait_before_click_execution
or 1.5,
)
print("Action - ENTER TEXT AND CLICK")
else:
result = f"Unsupported action type: {action.type}"
except Exception as e:
logger.error(f"Error executing action {action.type}: {e}")
result = f"Error executing action {action.type}: {str(e)}"

results.append(result)

Expand Down