diff --git a/.github/README.md b/.github/README.md index 87eea63..41d7ec4 100644 --- a/.github/README.md +++ b/.github/README.md @@ -19,7 +19,7 @@ A simple coding agent. ## How it Works **Source Agent** operates as a stateless entity, guided by clear directives and external context. Its behavior is primarily defined by **`AGENTS.md`**, which serves as the core system prompt. -![](docs/example4.gif) +![](https://github.com/christopherwoodall/source-agent/blob/main/.github/docs/example4.gif?raw=true) --- @@ -71,7 +71,7 @@ source-agent \ source-agent --interactive ``` -![](docs/example3.gif) +![](https://github.com/christopherwoodall/source-agent/blob/main/.github/docs/example3.gif?raw=true) --- diff --git a/pyproject.toml b/pyproject.toml index 8455ea4..a90aaa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ source-agent = "source_agent.entrypoint:main" [project] requires-python = ">=3.10" -version = "0.0.15" +version = "0.0.16" name = "source-agent" description = "Simple coding agent." readme = ".github/README.md" diff --git a/tests/test_code_agent.py b/tests/test_code_agent.py new file mode 100644 index 0000000..29b6da6 --- /dev/null +++ b/tests/test_code_agent.py @@ -0,0 +1,69 @@ +import json +import pytest +import json +from source_agent.agents.code import CodeAgent + + +class DummyFunction: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + +class DummyToolCall: + def __init__(self, function, id): + self.function = function + self.id = id + + +@pytest.fixture +def agent(monkeypatch): + c = CodeAgent(api_key="key", base_url="url", model="m", temperature=0.1, system_prompt="SP") + c.tool_mapping.clear() + return c + + +def test_parse_response_message_json(monkeypatch): + monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None) + json_snip = "{'type':'text','text':'hello world'}" + msg = f"prefix {json_snip} suffix" + c = CodeAgent(system_prompt="SP") + result = c.parse_response_message(msg) + assert result == "hello world" + + +def test_parse_response_message_plain(monkeypatch): + monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None) + c = CodeAgent(system_prompt="SP") + assert c.parse_response_message(" foo bar ") == "foo bar" + + +def test_handle_tool_call_success(agent): + def func(x, y): + return {"sum": x + y} + + agent.tool_mapping["foo"] = func + fn = DummyFunction(name="foo", arguments=json.dumps({"x": 2, "y": 3})) + call = DummyToolCall(function=fn, id="tid") + res = agent.handle_tool_call(call) + assert res["role"] == "tool" + assert res["tool_call_id"] == "tid" + assert res["name"] == "foo" + content = json.loads(res["content"]) + assert content == {"sum": 5} + + +def test_handle_tool_call_invalid_json(agent): + fn = DummyFunction(name="foo", arguments="notjson") + call = DummyToolCall(function=fn, id="tid") + res = agent.handle_tool_call(call) + content = json.loads(res["content"]) + assert "Invalid JSON arguments" in content["error"] + + +def test_handle_tool_call_unknown(agent): + fn = DummyFunction(name="bar", arguments=json.dumps({})) + call = DummyToolCall(function=fn, id="tid") + res = agent.handle_tool_call(call) + content = json.loads(res["content"]) + assert "Unknown tool" in content["error"] diff --git a/tests/test_e2e_code_agent.py b/tests/test_e2e_code_agent.py new file mode 100644 index 0000000..4b0130b --- /dev/null +++ b/tests/test_e2e_code_agent.py @@ -0,0 +1,65 @@ +import pytest +from source_agent.agents.code import CodeAgent, AgentEventType + + +@pytest.fixture(autouse=True) +def patch_openai(monkeypatch): + monkeypatch.setattr("source_agent.agents.code.openai.OpenAI", lambda *args, **kwargs: None) + + +class DummyFunction: + def __init__(self, name, arguments): + self.name = name + self.arguments = arguments + + +class DummyToolCall: + def __init__(self, function, id): + self.function = function + self.id = id + + +class DummyMessage: + def __init__(self, content, tool_calls): + self.content = content + self.tool_calls = tool_calls + + +class DummyChoice: + def __init__(self, message): + self.message = message + + +class DummyResponse: + def __init__(self, choice): + self.choices = [choice] + + +def test_run_stops_on_msg_complete(): + agent = CodeAgent(api_key="k", base_url="u", model="m", temperature=0, system_prompt="SP") + + def fake_call_llm(messages): + fn = DummyFunction(name="msg_complete_tool", arguments="{}") + tc = DummyToolCall(function=fn, id="1") + msg = DummyMessage(content="", tool_calls=[tc]) + return DummyResponse(DummyChoice(msg)) + + agent.call_llm = fake_call_llm + events = list(agent.run(user_prompt="hi", max_steps=1)) + types = [e.type for e in events] + assert AgentEventType.ITERATION_START in types + assert AgentEventType.TOOL_CALL in types + assert AgentEventType.TASK_COMPLETE in types + + +def test_run_max_steps_reached(): + agent = CodeAgent(api_key="k", base_url="u", model="m", temperature=0, system_prompt="SP") + + def fake_call_llm(messages): + msg = DummyMessage(content="hello", tool_calls=[]) + return DummyResponse(DummyChoice(msg)) + + agent.call_llm = fake_call_llm + events = list(agent.run(user_prompt="hi", max_steps=2)) + assert sum(1 for e in events if e.type == AgentEventType.ITERATION_START) == 2 + assert events[-1].type == AgentEventType.MAX_STEPS_REACHED diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..cbeb95d --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,46 @@ +import math +import re +import datetime +import pytest +from source_agent.tools.calculator_tool import calculate_expression_tool +from source_agent.tools.get_current_date import get_current_date +from source_agent.tools.msg_complete_tool import msg_complete_tool + + +def test_calculate_expression_simple(): + assert calculate_expression_tool("2+3") == 5 + + +def test_calculate_expression_funcs_and_constants(): + res = calculate_expression_tool("sqrt(16)+pi") + assert isinstance(res, float) + assert res == math.sqrt(16) + math.pi + + +@pytest.mark.parametrize( + "expr,err", + [ + ("2+*3", "Invalid expression syntax"), + ("5/0", "Division or modulo by zero"), + ("foo(3)", "Unsupported or non-callable function"), + ], +) +def test_calculate_expression_errors(expr, err): + out = calculate_expression_tool(expr) + assert "Error:" in out + assert err in out + + +def test_get_current_date_format(): + result = get_current_date() + assert result["success"] is True + dt = datetime.datetime.fromisoformat(result["current_datetime"]) + assert isinstance(dt, datetime.datetime) + + +def test_msg_complete_tool(): + result = msg_complete_tool() + assert result["success"] is True + content = result["content"] + assert content["status"] == "completed" + assert re.match(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}", content["timestamp"]) diff --git a/tests/test_tools_extra.py b/tests/test_tools_extra.py new file mode 100644 index 0000000..4a5796d --- /dev/null +++ b/tests/test_tools_extra.py @@ -0,0 +1,224 @@ +import os +import subprocess +import pathlib +import pytest +import tempfile +import shutil +import re +from source_agent.tools.file_list_tool import file_list_tool, load_gitignore +from source_agent.tools.file_read_tool import file_read_tool +from source_agent.tools.file_write_tool import file_write_tool +from source_agent.tools.file_search_tool import file_search_tool, build_plain_text_matcher, is_subpath +from source_agent.tools.execute_shell_command_tool import execute_shell_command +from source_agent.tools.run_pytest_tests_tool import run_pytest_tests +from source_agent.tools.directory_create_tool import directory_create_tool +from source_agent.tools.directory_delete_tool import directory_delete_tool +from source_agent.tools.file_delete_tool import file_delete_tool +from source_agent.tools.tool_registry import ToolRegistry +from source_agent.tools.web_search_tool import web_search_tool + +# file_list_tool + + +def test_file_list_not_found(tmp_path): + os.chdir(tmp_path) + res = file_list_tool(path="nofile") + assert not res["success"] and "Path not found" in res["error"] + + +def test_file_list_not_dir(tmp_path): + os.chdir(tmp_path) + f = tmp_path / "f.txt" + f.write_text("x") + res = file_list_tool(path="f.txt") + assert not res["success"] and "Not a directory" in res["error"] + + +def test_file_list_non_recursive(tmp_path): + os.chdir(tmp_path) + (tmp_path / "a").mkdir() + (tmp_path / "b.txt").write_text("x") + res = file_list_tool(path=".", recursive=False) + assert res["success"] + assert "b.txt" in res["files"] and "a/" in res["files"] + + +def test_file_list_recursive(tmp_path): + os.chdir(tmp_path) + (tmp_path / "d").mkdir() + (tmp_path / "d" / "c.txt").write_text("y") + res = file_list_tool(path=".", recursive=True) + assert res["success"] and "d/c.txt" in res["files"] + + +# file_read_tool + + +def test_file_read_success(tmp_path): + os.chdir(tmp_path) + f = tmp_path / "r.txt" + f.write_text("hello") + res = file_read_tool("r.txt") + assert res["success"] and res["content"] == "hello" + + +def test_file_read_traversal(tmp_path): + os.chdir(tmp_path) + res = file_read_tool("../etc/passwd") + assert not res["success"] and "Path traversal" in res["error"] + + +def test_file_read_not_found(tmp_path): + os.chdir(tmp_path) + res = file_read_tool("nofile") + assert not res["success"] and "File not found" in res["error"] + + +# file_write_tool + + +def test_file_write_success(tmp_path): + os.chdir(tmp_path) + res = file_write_tool("w.txt", "data") + assert res["success"] and (tmp_path / "w.txt").read_text() == "data" + + +def test_file_write_traversal(tmp_path): + os.chdir(tmp_path) + res = file_write_tool("../w.txt", "x") + assert not res["success"] and "Path traversal" in res["error"] + + +# file_search_tool + + +def test_file_search_name_only(tmp_path): + os.chdir(tmp_path) + f = tmp_path / "f1.txt" + f.write_text("X\nY\nZ") + res = file_search_tool(name="*.txt") + assert res["success"] and any("f1.txt" in item for item in res["content"]) + + +def test_file_search_with_pattern(tmp_path): + os.chdir(tmp_path) + f = tmp_path / "f2.txt" + f.write_text("hello world\nabc hello") + res = file_search_tool(name="f2.txt", pattern="hello", ignore_case=False) + assert res["success"] and any(":1:" in item for item in res["content"]) + + +# helper tests + + +def test_build_plain_text_matcher(): + m = build_plain_text_matcher("AbC", ignore_case=True) + assert m("xxabcxx") + m2 = build_plain_text_matcher("AbC", ignore_case=False) + assert not m2("xxabcxx") + + +def test_is_subpath(tmp_path): + base = tmp_path + sub = tmp_path / "s" + sub.mkdir() + assert is_subpath(sub, base) + outside = pathlib.Path(tmp_path.parent) + assert not is_subpath(outside, base) + + +# execute_shell_command + + +def test_execute_shell_command_success(): + res = execute_shell_command("echo hi") + assert res["success"] and res["stdout"] == "hi" + + +# run_pytest_tests + + +def test_run_pytest_tests_traversal(tmp_path): + os.chdir(tmp_path) + res = run_pytest_tests(target_paths=[str(tmp_path.parent)]) + assert not res["success"] and "Path traversal" in res["error"] + + +def test_run_pytest_tests_no_tests(tmp_path): + os.chdir(tmp_path) + res = run_pytest_tests() + assert res["exit_code"] != 0 + + +# directory_create_tool + + +def test_directory_create_tool_success(tmp_path): + os.chdir(tmp_path) + res = directory_create_tool(path="d1") + assert res["success"] and (tmp_path / "d1").exists() + + +def test_directory_create_tool_traversal(tmp_path): + os.chdir(tmp_path) + res = directory_create_tool(path="../d") + assert not res["success"] and "Path traversal" in res["error"] + + +# directory_delete_tool + + +def test_directory_delete_tool_success(tmp_path): + os.chdir(tmp_path) + d = tmp_path / "dd" + d.mkdir() + res = directory_delete_tool(path="dd") + assert res["success"] and not d.exists() + + +def test_directory_delete_tool_recursive(tmp_path): + os.chdir(tmp_path) + d = tmp_path / "rr" + (d / "x").mkdir(parents=True) + res = directory_delete_tool(path="rr", recursive=True) + assert res["success"] and not d.exists() + + +# file_delete_tool + + +def test_file_delete_tool_success(tmp_path): + os.chdir(tmp_path) + f = tmp_path / "fx.txt" + f.write_text("hi") + res = file_delete_tool(path="fx.txt") + assert res["success"] and not f.exists() + + +# tool_registry + + +def test_tool_registry_register(): + tr = ToolRegistry() + + @tr.register(name="n", description="d", parameters={}) + def fn(): + return True + + tools = tr.get_tools() + mapping = tr.get_mapping() + assert any(t["function"]["name"] == "n" for t in tools) + assert mapping["n"] is fn + + +# web_search_tool error + + +def test_web_search_tool_failure(monkeypatch): + class BadDDGS: + def text(self, query, max_results): + raise RuntimeError("fail") + + monkeypatch.setattr("source_agent.tools.web_search_tool.DDGS", BadDDGS) + res = web_search_tool(query="x") + assert not res["success"] and "Search failed" in res["content"][0]