Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/any_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ class CalendarEvent(BaseModel):
Refer to [any-llm Completion API Docs](https://mozilla-ai.github.io/any-llm/api/completion/) for more info.
"""

any_llm_args: MutableMapping[str, Any] | None = None
"""Pass arguments to `AnyLLM.create()` when using integrations backed by any-llm.

Use this for provider/client initialization options that are not completion-time
generation params (which should be passed via `model_args`).
"""

output_type: type[BaseModel] | None = None
"""Control the output schema from calling `run`. By default, the agent will return a type str.

Expand Down
25 changes: 17 additions & 8 deletions src/any_agent/frameworks/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
max_retries: int = 10,
api_key: str | None = None,
api_base: str | None = None,
any_llm_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
additional_kwargs = additional_kwargs or {}
Expand All @@ -116,10 +117,14 @@ def __init__(
)

self._parse_model(model)
llm_create_kwargs: dict[str, Any] = dict(any_llm_args or {})
if api_key is not None:
llm_create_kwargs["api_key"] = api_key
if api_base is not None:
llm_create_kwargs["api_base"] = api_base
self._client = AnyLLM.create(
provider=self._provider,
api_key=api_key,
api_base=api_base,
**llm_create_kwargs,
)

def _parse_model(self, model: str) -> None:
Expand Down Expand Up @@ -512,14 +517,18 @@ def _get_model(self, agent_config: AgentConfig) -> "LLM":

model_id = agent_config.model_id

model_kwargs: dict[str, Any] = {
"model": model_id,
"api_key": agent_config.api_key,
"api_base": agent_config.api_base,
"additional_kwargs": additional_kwargs,
}
if model_type is DEFAULT_MODEL_TYPE and agent_config.any_llm_args is not None:
model_kwargs["any_llm_args"] = agent_config.any_llm_args

return cast(
"LLM",
model_type(
model=model_id,
api_key=agent_config.api_key,
api_base=agent_config.api_base,
additional_kwargs=additional_kwargs, # type: ignore[arg-type]
),
model_type(**model_kwargs),
)

async def _load_agent(self) -> None:
Expand Down
22 changes: 16 additions & 6 deletions src/any_agent/frameworks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,19 @@ def __init__(
model: str,
base_url: str | None = None,
api_key: str | None = None,
any_llm_args: dict[str, Any] | None = None,
):
provider, model_id = AnyLLM.split_model_provider(model)
self.model = model
self.base_url = base_url
self.api_key = api_key
self.llm = AnyLLM.create(provider=provider, api_key=api_key, api_base=base_url)
llm_create_kwargs: dict[str, Any] = dict(any_llm_args or {})
if api_key is not None:
llm_create_kwargs["api_key"] = api_key
if base_url is not None:
llm_create_kwargs["api_base"] = base_url

self.llm = AnyLLM.create(provider=provider, **llm_create_kwargs)
self.model_id = model_id

async def get_response(
Expand Down Expand Up @@ -399,11 +406,14 @@ def _get_model(
base_url = agent_config.api_base or cast(
"str | None", model_args.get("api_base")
)
return model_type(
model=agent_config.model_id,
base_url=base_url,
api_key=agent_config.api_key,
)
model_kwargs: dict[str, Any] = {
"model": agent_config.model_id,
"base_url": base_url,
"api_key": agent_config.api_key,
}
if model_type is DEFAULT_MODEL_TYPE and agent_config.any_llm_args is not None:
model_kwargs["any_llm_args"] = agent_config.any_llm_args
return model_type(**model_kwargs)

async def _load_agent(self) -> None:
"""Load the OpenAI agent with the given configuration."""
Expand Down
2 changes: 1 addition & 1 deletion src/any_agent/frameworks/tinyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, config: AgentConfig) -> None:
self.uses_openai = provider_name == LLMProvider.OPENAI

# Create the LLM instance using the AnyLLM class pattern
llm_kwargs: dict[str, Any] = {}
llm_kwargs: dict[str, Any] = dict(self.config.any_llm_args or {})
if self.config.api_key:
llm_kwargs["api_key"] = self.config.api_key
if self.config.api_base:
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/frameworks/test_llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,34 @@ def test_run_llama_index_agent_custom_args() -> None:
)
agent.run("foo", timeout=10)
agent_mock.run.assert_called_once_with("foo", timeout=10)


def test_load_llama_index_agent_forwards_any_llm_args() -> None:
    """Verify that `AgentConfig.any_llm_args` is passed through to the default
    llama_index model constructor alongside the standard keyword arguments."""
    from llama_index.core.tools import FunctionTool

    extra_args = {"timeout": 17, "max_retries": 3}

    # Stub out the framework-level agent and model types so no real client is built.
    mock_model_type = MagicMock()
    mock_agent_type = MagicMock()
    mock_agent_type.return_value = MagicMock()

    config = AgentConfig(
        model_id="gemini/gemini-2.0-flash",
        instructions="You are a helpful assistant",
        any_llm_args=extra_args,
    )

    with (
        patch("any_agent.frameworks.llama_index.DEFAULT_AGENT_TYPE", mock_agent_type),
        patch("any_agent.frameworks.llama_index.DEFAULT_MODEL_TYPE", mock_model_type),
        patch.object(FunctionTool, "from_defaults"),
    ):
        AnyAgent.create(AgentFramework.LLAMA_INDEX, config)

    # The model type must receive the user-supplied any_llm_args verbatim.
    mock_model_type.assert_called_once_with(
        model="gemini/gemini-2.0-flash",
        api_key=None,
        api_base=None,
        additional_kwargs={},
        any_llm_args=extra_args,
    )
25 changes: 25 additions & 0 deletions tests/unit/frameworks/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ def test_openai_with_api_key() -> None:
)


def test_openai_forwards_any_llm_args() -> None:
    """Verify that `AgentConfig.any_llm_args` reaches the default OpenAI-framework
    model constructor together with the usual model/base_url/api_key kwargs."""
    extra_args = {"timeout": 42, "headers": {"x-test": "1"}}

    # Replace the Agent class and model type so nothing real is instantiated.
    stub_agent = MagicMock()
    stub_model_type = MagicMock()

    config = AgentConfig(
        model_id="mistral:mistral-small-latest",
        any_llm_args=extra_args,
    )

    with (
        patch("any_agent.frameworks.openai.Agent", stub_agent),
        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", stub_model_type),
    ):
        AnyAgent.create(AgentFramework.OPENAI, config)

    # The model type must be constructed exactly once, with any_llm_args forwarded.
    stub_model_type.assert_called_once_with(
        model="mistral:mistral-small-latest",
        base_url=None,
        api_key=None,
        any_llm_args=extra_args,
    )


def test_load_openai_with_mcp_server() -> None:
mock_agent = MagicMock()
mock_function_tool = MagicMock()
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/frameworks/test_tinyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,21 @@ def test_uses_openai_handles_gateway_provider(
assert agent.uses_openai is expected_uses_openai


def test_tinyagent_forwards_any_llm_args_to_anyllm_create() -> None:
    """Verify that TinyAgent forwards `AgentConfig.any_llm_args` as keyword
    arguments to `AnyLLM.create` for the resolved provider."""
    extra_args = {"timeout": 99, "organization": "test-org"}
    expected_provider, _ = AnyLLM.split_model_provider(DEFAULT_SMALL_MODEL_ID)

    config = AgentConfig(
        model_id=DEFAULT_SMALL_MODEL_ID,
        any_llm_args=extra_args,
    )

    with patch("any_agent.frameworks.tinyagent.AnyLLM.create") as mock_create:
        TinyAgent(config)

    # Provider is positional; every any_llm_args entry becomes a keyword argument.
    mock_create.assert_called_once_with(expected_provider, **extra_args)


@pytest.mark.asyncio
async def test_tool_result_appended_when_tool_not_found() -> None:
"""Test that tool_result message is appended when a tool is not found.
Expand Down