Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/any_agent/frameworks/agno.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,15 @@ async def _load_agent(self) -> None:
**agent_args,
)

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "Agno framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
result: RunResponse = await self._agent.arun(prompt, **kwargs)
return result.content # type: ignore[return-value]

Expand Down
18 changes: 14 additions & 4 deletions src/any_agent/frameworks/any_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,25 @@ async def __aexit__(
"""Exit the async context manager and clean up resources."""
await self.cleanup_async()

def run(self, prompt: str, **kwargs: Any) -> AgentTrace:
def run(self, prompt: str | list[dict[str, Any]], **kwargs: Any) -> AgentTrace:
"""Run the agent with the given prompt."""
return run_async_in_sync(
self.run_async(prompt, **kwargs), allow_running_loop=INSIDE_NOTEBOOK
)

async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace:
async def run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> AgentTrace:
"""Run the agent asynchronously with the given prompt.

Args:
prompt: The user prompt to be passed to the agent.
prompt: The user prompt to be passed to the agent. Can be a plain
string or a list of message dicts (e.g.
``[{"role": "user", "content": "hello"}]``) following the
OpenAI chat-completion message format. When a list is provided
it is forwarded directly to the underlying LLM, giving callers
full control over the conversation structure.
Note: passing a list is only supported by the ``TINYAGENT`` framework.

kwargs: Will be passed to the underlying runner used
by the framework.
Expand Down Expand Up @@ -530,7 +538,9 @@ async def _load_agent(self) -> None:
"""Load the agent instance."""

@abstractmethod
async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
"""To be implemented by each framework."""

@abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion src/any_agent/frameworks/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,17 @@ async def _load_agent(self) -> None:

async def _run_async( # type: ignore[no-untyped-def]
self,
prompt: str,
prompt: str | list[dict[str, Any]],
user_id: str | None = None,
session_id: str | None = None,
**kwargs,
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "Google framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
runner = InMemoryRunner(self._agent)
user_id = user_id or str(uuid4())
session_id = session_id or str(uuid4())
Expand Down
7 changes: 6 additions & 1 deletion src/any_agent/frameworks/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,15 @@ async def _load_agent(self) -> None:
**agent_args,
)

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "LangChain framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
inputs = {"messages": [("user", prompt)]}
result = await self._agent.ainvoke(inputs, **kwargs)

Expand Down
7 changes: 6 additions & 1 deletion src/any_agent/frameworks/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,15 @@ async def _load_agent(self) -> None:
**self.config.agent_args or {},
)

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "LlamaIndex framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
result: AgentOutput = await self._agent.run(prompt, **kwargs)
# assert that it's a TextBlock
if not result.response.blocks or not hasattr(result.response.blocks[0], "text"):
Expand Down
7 changes: 6 additions & 1 deletion src/any_agent/frameworks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,15 @@ def _filter_mcp_tools(self, tools: list[Any], mcp_clients: list[Any]) -> list[An
# The OpenAI framework can handle them as regular tools.
return tools

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "OpenAI framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
if not kwargs.get("max_turns"):
kwargs["max_turns"] = math.inf
result = await Runner.run(self._agent, prompt, **kwargs)
Expand Down
7 changes: 6 additions & 1 deletion src/any_agent/frameworks/smolagents.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,15 @@ async def _load_agent(self) -> None:

assert self._agent

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if not self._agent:
error_message = "Agent not loaded. Call load_agent() first."
raise ValueError(error_message)
if not isinstance(prompt, str):
msg = "Smolagents framework does not support list of messages as input. Use a plain string prompt."
raise NotImplementedError(msg)
result = self._agent.run(prompt, **kwargs)
if self.config.output_type:
result_json = (
Expand Down
27 changes: 16 additions & 11 deletions src/any_agent/frameworks/tinyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,22 +171,27 @@ async def _load_agent(self) -> None:
self.completion_params["tools"].append(function_def)
self.clients[tool_name] = ToolExecutor(tool)

async def _run_async(self, prompt: str, **kwargs: Any) -> str | BaseModel:
async def _run_async(
self, prompt: str | list[dict[str, Any]], **kwargs: Any
) -> str | BaseModel:
if self.uses_openai:
self.completion_params["tool_choice"] = "auto"
if self.config.output_type:
self.completion_params["response_format"] = self.config.output_type

messages = [
{
"role": "system",
"content": self.config.instructions or DEFAULT_SYSTEM_PROMPT,
},
{
"role": "user",
"content": prompt,
},
]
if isinstance(prompt, list):
messages = prompt
else:
messages = [
{
"role": "system",
"content": self.config.instructions or DEFAULT_SYSTEM_PROMPT,
},
{
"role": "user",
"content": prompt,
},
]

if kwargs.pop("max_turns", None):
logger.warning(
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/frameworks/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,20 @@ def write_file(text: str) -> None:
html_output = console.export_html(inline_styles=True)
with open(f"{trace_path}_trace.html", "w", encoding="utf-8") as f:
f.write(html_output.replace("<!DOCTYPE html>", ""))


def test_tinyagent_run_with_messages() -> None:
    """Verify that TinyAgent.run accepts an OpenAI-style list of message dicts.

    Only the TINYAGENT framework supports list input; the conversation is
    forwarded directly to the underlying LLM instead of being wrapped in the
    default system/user scaffolding.
    """
    # Build the conversation up front: an explicit system turn plus a user turn.
    conversation = [
        {"role": "system", "content": "You are a helpful assistant. Be concise."},
        {"role": "user", "content": "What is 2 + 2? Reply with just the number."},
    ]

    tiny_agent = AnyAgent.create(
        AgentFramework.TINYAGENT,
        AgentConfig(model_id=DEFAULT_SMALL_MODEL_ID),
    )

    trace = tiny_agent.run(conversation)

    # The model should produce a non-empty answer containing the digit 4.
    assert trace.final_output
    assert "4" in str(trace.final_output)