Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions src/praga_core/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from openai.types.chat import ChatCompletionMessageParam

from praga_core.retriever import RetrieverAgentBase
from praga_core.types import PageReference
from praga_core.types import Page, PageReference

from .format_instructions import get_agent_format_instructions
from .response import (
Expand Down Expand Up @@ -150,6 +150,9 @@ def __init__(
# Configure logging
self._configure_logging()

# Track pages accessed during search
self._accessed_pages: Dict[str, Page] = {}

async def search(self, query: str) -> List[PageReference]:
"""
Execute a search using the ReAct agent's approach.
Expand All @@ -164,6 +167,9 @@ async def search(self, query: str) -> List[PageReference]:
logger.info("Starting RetrieverAgent search")
logger.info("Query: %s", query)

# Clear accessed pages for new search
self._accessed_pages.clear()

try:
return await self._run_agent(query)
except Exception as e:
Expand Down Expand Up @@ -288,13 +294,44 @@ def _handle_agent_finish(self, agent_finish: AgentFinish) -> List[PageReference]
if agent_response.response_code != ResponseCode.SUCCESS:
logger.error("Agent response code: %s", agent_response.response_code)
return []

# Resolve references using tracked pages
resolved_references = self._resolve_references_internally(
agent_response.references
)

logger.info("=" * 80)
logger.info(
f"Search completed: Found {len(agent_response.references)} document references"
f"Search completed: Found {len(resolved_references)} document references"
)
logger.info("=" * 80)

return agent_response.references
return resolved_references

def _resolve_references_internally(
self, references: List[PageReference]
) -> List[PageReference]:
"""Resolve PageReference objects using tracked pages."""
resolved_references = []

for ref in references:
uri_str = str(ref.uri)

# Try to find the page in our tracked pages
if uri_str in self._accessed_pages:
# Create a new reference with the resolved page
resolved_ref = PageReference(
uri=ref.uri, score=ref.score, explanation=ref.explanation
)
resolved_ref.page = self._accessed_pages[uri_str]
resolved_references.append(resolved_ref)
logger.debug(f"Resolved reference: {uri_str}")
else:
# Keep the original reference if we can't resolve it
resolved_references.append(ref)
logger.debug(f"Could not resolve reference: {uri_str}")

return resolved_references

async def _handle_agent_action(
self,
Expand All @@ -321,7 +358,10 @@ async def _execute_tool(self, action: AgentAction, iteration: int) -> str:
if toolkit is None:
raise ValueError(f"Tool '{action.action}' not found in any toolkit")

result = await toolkit.invoke_tool(action.action, action.action_input)
# Execute tool with page tracking
result = await toolkit.invoke_tool(
action.action, action.action_input, callbacks=[self._track_pages]
)

observation = Observation(action=action.action, result=result)
observation_content = observation.to_json()
Expand Down Expand Up @@ -405,6 +445,13 @@ def _parse_llm_output(
).model_dump(),
)

def _track_pages(self, tool_name: str, pages: Sequence[Page]) -> None:
"""Track pages from tool results."""
for page in pages:
uri_str = str(page.uri)
self._accessed_pages[uri_str] = page
logger.debug(f"Tracked page from {tool_name}: {uri_str}")

def _extract_json_from_markdown(self, text: str) -> str:
"""Extract JSON content from markdown code blocks."""
text = text.strip()
Expand Down
19 changes: 18 additions & 1 deletion src/praga_core/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# Type variable bound to Document for generic pagination
T = TypeVar("T", bound=Page)

# Tool callback types
ToolCallback = Callable[[str, Sequence[Page]], None]


@dataclass(frozen=True)
class PaginatedResponse(Generic[T], ABCSequence[T]):
Expand Down Expand Up @@ -291,7 +294,11 @@ async def __call__(self, **kwargs: Any) -> ToolReturnType:
else:
return await self._handle_client_side_pagination(**kwargs)

async def invoke(self, raw_input: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
async def invoke(
self,
raw_input: Union[str, Dict[str, Any]],
callbacks: Optional[List[ToolCallback]] = None,
) -> Dict[str, Any]:
"""Execute the tool with the given input and serialize the response."""
try:
kwargs = self._prepare_arguments(raw_input)
Expand All @@ -302,6 +309,16 @@ async def invoke(self, raw_input: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
"references": [],
"error_message": "No matching documents found",
}

# Execute callbacks before serialization if provided
if callbacks:
# Extract the actual pages from the result
pages = (
result.results if isinstance(result, PaginatedResponse) else result
)
for callback in callbacks:
callback(self.name, pages)

return self._serialize_result(result)

except ValueError as e:
Expand Down
10 changes: 7 additions & 3 deletions src/praga_core/agents/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -145,7 +146,10 @@ def get_tool(self, name: str) -> Tool:
return self._tools[name]

async def invoke_tool(
self, name: str, raw_input: Union[str, Dict[str, Any]]
self,
name: str,
raw_input: Union[str, Dict[str, Any]],
callbacks: Optional[List[Callable[[str, Sequence[Page]], None]]] = None,
) -> Dict[str, Any]:
"""Invoke a tool by name with pagination support."""
tool = self.get_tool(name)
Expand All @@ -157,12 +161,12 @@ async def invoke_tool(
name=name, type="tool", show_input="python", language="python"
) as step:
step.input = raw_input
response = await tool.invoke(raw_input)
response = await tool.invoke(raw_input, callbacks)
step.output = response
return response
except (ImportError, AttributeError):
pass
return await tool.invoke(raw_input)
return await tool.invoke(raw_input, callbacks)

@property
def tools(self) -> Dict[str, Tool]:
Expand Down
13 changes: 0 additions & 13 deletions src/praga_core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ async def search(
self,
instruction: str,
retriever: Optional[RetrieverAgentBase] = None,
resolve_references: bool = True,
) -> SearchResponse:
"""Execute search using the provided retriever."""
active_retriever = retriever or self.retriever
Expand All @@ -76,8 +75,6 @@ async def search(
)

results = await self._search(instruction, active_retriever)
if resolve_references:
results = await self._resolve_references(results)
return SearchResponse(results=results)

async def _search(
Expand All @@ -95,16 +92,6 @@ async def _search(
results = await retriever.search(instruction)
return results

async def _resolve_references(
self, results: List[PageReference]
) -> List[PageReference]:
"""Resolve references to pages by calling get_page."""
uris = [ref.uri for ref in results]
pages = await self.get_pages(uris)
for ref, page in zip(results, pages):
ref.page = page
return results

@property
def root(self) -> str:
"""Get the root path for this context."""
Expand Down
61 changes: 31 additions & 30 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,8 @@ async def test_search_with_context_retriever(
assert len(result.results) == 3
assert mock_retriever.search_calls == ["test query"]

# Verify pages were resolved
for ref in result.results:
assert ref._page is not None
# Note: Reference resolution is now handled by the agent, not the context
# The mock agent returns the same references it was given

@pytest.mark.asyncio
async def test_search_with_parameter_retriever(
Expand All @@ -167,20 +166,19 @@ async def test_search_no_retriever_error(self, context) -> None:
await context.search("test query")

@pytest.mark.asyncio
async def test_search_without_resolve_references(
async def test_search_returns_agent_results_directly(
self, context, sample_page_references: List[PageReference]
) -> None:
"""Test search without resolving references."""
"""Test search returns results directly from agent."""
mock_retriever = MockRetrieverAgent(sample_page_references)
context.retriever = mock_retriever

result = await context.search("test query", resolve_references=False)
result = await context.search("test query")

assert isinstance(result, SearchResponse)
assert len(result.results) == 3
# Verify pages were NOT resolved
for ref in result.results:
assert ref._page is None
# Note: Reference resolution now happens in the agent, not context
# So pages may or may not be resolved depending on agent implementation

@pytest.mark.asyncio
async def test_search_parameter_retriever_overrides_context(
Expand All @@ -203,22 +201,29 @@ async def test_search_parameter_retriever_overrides_context(
assert param_retriever.search_calls == ["test query"]


class TestReferenceResolution:
"""Test reference resolution functionality."""
class TestPageResolution:
"""Test page resolution functionality through context routing."""

@pytest.mark.asyncio
async def test_resolve_references(
async def test_get_pages_resolves_correctly(
self, context, sample_page_references: List[PageReference]
) -> None:
"""Test resolving page references."""
"""Test resolving page references using get_pages method."""
context.route("document")(document_page_handler)
context.route("alternate")(alternate_page_handler)

resolved_refs = await context._resolve_references(sample_page_references)
uris = [ref.uri for ref in sample_page_references]
resolved_pages = await context.get_pages(uris)

assert len(resolved_pages) == 3
assert resolved_pages[0] is not None
assert resolved_pages[1] is not None
assert resolved_pages[2] is not None

assert len(resolved_refs) == 3
for ref in resolved_refs:
assert ref.page is not None
# Check correct types were used
assert isinstance(resolved_pages[0], DocumentPage)
assert isinstance(resolved_pages[1], DocumentPage)
assert isinstance(resolved_pages[2], AlternateTestPage)


class TestIntegration:
Expand All @@ -245,22 +250,18 @@ async def test_mixed_page_types_workflow(self, context) -> None:
context.route("alternate")(alternate_page_handler)

# Create mixed references
refs = [
PageReference(
uri=PageURI(root="test", type="document", id="doc1", version=1)
),
PageReference(
uri=PageURI(root="test", type="alternate", id="alt1", version=1)
),
uris = [
PageURI(root="test", type="document", id="doc1", version=1),
PageURI(root="test", type="alternate", id="alt1", version=1),
]

resolved_refs = await context._resolve_references(refs)
resolved_pages = await context.get_pages(uris)

assert len(resolved_refs) == 2
assert isinstance(resolved_refs[0].page, DocumentPage)
assert isinstance(resolved_refs[1].page, AlternateTestPage)
assert resolved_refs[0].page.title == "Test Page doc1"
assert resolved_refs[1].page.name == "Alternate Page alt1"
assert len(resolved_pages) == 2
assert isinstance(resolved_pages[0], DocumentPage)
assert isinstance(resolved_pages[1], AlternateTestPage)
assert resolved_pages[0].title == "Test Page doc1"
assert resolved_pages[1].name == "Alternate Page alt1"


class TestValidatorIntegration:
Expand Down
Loading