Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ A simple coding agent.
## How it Works
**Source Agent** operates as a stateless entity, guided by clear directives and external context. Its behavior is primarily defined by **`AGENTS.md`**, which serves as the core system prompt.

![](docs/example4.gif)
![](https://github.com/christopherwoodall/source-agent/blob/main/.github/docs/example4.gif?raw=true)

---

Expand Down Expand Up @@ -71,7 +71,7 @@ source-agent \
source-agent --interactive
```

![](docs/example3.gif)
![](https://github.com/christopherwoodall/source-agent/blob/main/.github/docs/example3.gif?raw=true)

---

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ source-agent = "source_agent.entrypoint:main"

[project]
requires-python = ">=3.10"
version = "0.0.15"
version = "0.0.16"
name = "source-agent"
description = "Simple coding agent."
readme = ".github/README.md"
Expand Down
69 changes: 69 additions & 0 deletions tests/test_code_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
import pytest
import json
from source_agent.agents.code import CodeAgent


class DummyFunction:
def __init__(self, name, arguments):
self.name = name
self.arguments = arguments


class DummyToolCall:
def __init__(self, function, id):
self.function = function
self.id = id


@pytest.fixture
def agent(monkeypatch):
c = CodeAgent(api_key="key", base_url="url", model="m", temperature=0.1, system_prompt="SP")
c.tool_mapping.clear()
return c


def test_parse_response_message_json(monkeypatch):
monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None)
json_snip = "{'type':'text','text':'hello world'}"
msg = f"prefix {json_snip} suffix"
c = CodeAgent(system_prompt="SP")
result = c.parse_response_message(msg)
assert result == "hello world"


def test_parse_response_message_plain(monkeypatch):
monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None)
c = CodeAgent(system_prompt="SP")
assert c.parse_response_message(" foo bar ") == "foo bar"


def test_handle_tool_call_success(agent):
def func(x, y):
return {"sum": x + y}

agent.tool_mapping["foo"] = func
fn = DummyFunction(name="foo", arguments=json.dumps({"x": 2, "y": 3}))
call = DummyToolCall(function=fn, id="tid")
res = agent.handle_tool_call(call)
assert res["role"] == "tool"
assert res["tool_call_id"] == "tid"
assert res["name"] == "foo"
content = json.loads(res["content"])
assert content == {"sum": 5}


def test_handle_tool_call_invalid_json(agent):
fn = DummyFunction(name="foo", arguments="notjson")
call = DummyToolCall(function=fn, id="tid")
res = agent.handle_tool_call(call)
content = json.loads(res["content"])
assert "Invalid JSON arguments" in content["error"]


def test_handle_tool_call_unknown(agent):
fn = DummyFunction(name="bar", arguments=json.dumps({}))
call = DummyToolCall(function=fn, id="tid")
res = agent.handle_tool_call(call)
content = json.loads(res["content"])
assert "Unknown tool" in content["error"]
65 changes: 65 additions & 0 deletions tests/test_e2e_code_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
from source_agent.agents.code import CodeAgent, AgentEventType


@pytest.fixture(autouse=True)
def patch_openai(monkeypatch):
monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None)


class DummyFunction:
def __init__(self, name, arguments):
self.name = name
self.arguments = arguments


class DummyToolCall:
def __init__(self, function, id):
self.function = function
self.id = id


class DummyMessage:
def __init__(self, content, tool_calls):
self.content = content
self.tool_calls = tool_calls


class DummyChoice:
def __init__(self, message):
self.message = message


class DummyResponse:
def __init__(self, choice):
self.choices = [choice]


def test_run_stops_on_msg_complete():
agent = CodeAgent(api_key="k", base_url="u", model="m", temperature=0, system_prompt="SP")

def fake_call_llm(messages):
fn = DummyFunction(name="msg_complete_tool", arguments="{}")
tc = DummyToolCall(function=fn, id="1")
msg = DummyMessage(content="", tool_calls=[tc])
return DummyResponse(DummyChoice(msg))

agent.call_llm = fake_call_llm
events = list(agent.run(user_prompt="hi", max_steps=1))
types = [e.type for e in events]
assert AgentEventType.ITERATION_START in types
assert AgentEventType.TOOL_CALL in types
assert AgentEventType.TASK_COMPLETE in types


def test_run_max_steps_reached():
agent = CodeAgent(api_key="k", base_url="u", model="m", temperature=0, system_prompt="SP")

def fake_call_llm(messages):
msg = DummyMessage(content="hello", tool_calls=[])
return DummyResponse(DummyChoice(msg))

agent.call_llm = fake_call_llm
events = list(agent.run(user_prompt="hi", max_steps=2))
assert sum(1 for e in events if e.type == AgentEventType.ITERATION_START) == 2
assert events[-1].type == AgentEventType.MAX_STEPS_REACHED
46 changes: 46 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import math
import re
import datetime
import pytest
from source_agent.tools.calculator_tool import calculate_expression_tool
from source_agent.tools.get_current_date import get_current_date
from source_agent.tools.msg_complete_tool import msg_complete_tool


def test_calculate_expression_simple():
assert calculate_expression_tool("2+3") == 5


def test_calculate_expression_funcs_and_constants():
res = calculate_expression_tool("sqrt(16)+pi")
assert isinstance(res, float)
assert res == math.sqrt(16) + math.pi


@pytest.mark.parametrize(
"expr,err",
[
("2+*3", "Invalid expression syntax"),
("5/0", "Division or modulo by zero"),
("foo(3)", "Unsupported or non-callable function"),
],
)
def test_calculate_expression_errors(expr, err):
out = calculate_expression_tool(expr)
assert "Error:" in out
assert err in out


def test_get_current_date_format():
result = get_current_date()
assert result["success"] is True
dt = datetime.datetime.fromisoformat(result["current_datetime"])
assert isinstance(dt, datetime.datetime)


def test_msg_complete_tool():
result = msg_complete_tool()
assert result["success"] is True
content = result["content"]
assert content["status"] == "completed"
assert re.match(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}", content["timestamp"])
Loading