From 50f3d7555c74debeb0c0463334dc5ca4789574e3 Mon Sep 17 00:00:00 2001 From: observerw Date: Sat, 10 Jan 2026 16:20:17 +0800 Subject: [PATCH 01/10] feat: add Relation API support --- docs/schemas/draft/relation.md | 84 ++++--- src/lsap/capability/__init__.py | 2 + src/lsap/capability/relation.py | 147 +++++++++++ src/lsap/schema/__init__.py | 5 + src/lsap/schema/draft/relation.py | 54 ---- src/lsap/schema/relation.py | 83 +++++++ tests/test_relation.py | 399 ++++++++++++++++++++++++++++++ 7 files changed, 682 insertions(+), 92 deletions(-) create mode 100644 src/lsap/capability/relation.py delete mode 100644 src/lsap/schema/draft/relation.py create mode 100644 src/lsap/schema/relation.py create mode 100644 tests/test_relation.py diff --git a/docs/schemas/draft/relation.md b/docs/schemas/draft/relation.md index aec2020..cb25769 100644 --- a/docs/schemas/draft/relation.md +++ b/docs/schemas/draft/relation.md @@ -1,45 +1,47 @@ # Relation API -The Relation API allows finding all call chains (paths) that connect two specific symbols. This is useful for understanding how one part of the system interacts with another, validating architectural dependencies, or impact analysis. +**Core Value**: Trace the call path between two symbols — answering "how does A reach B?" -It leverages the [Hierarchy API](hierarchy.md) to trace call relationship paths. +This is a high-value query for: + +- **Code Flow Understanding**: How does `handle_request` eventually call `db.query`? +- **Impact Analysis**: If I modify function X, which entry points are affected? +- **Architecture Validation**: Verify that module A never directly/indirectly calls module B ## RelationRequest -| Field | Type | Default | Description | -| :---------- | :-------------------- | :------- | :--------------------------------------- | -| `source` | [`Locate`](locate.md) | Required | The starting symbol for the path search. | -| `target` | [`Locate`](locate.md) | Required | The ending symbol for the path search. | -| `max_depth` | `number` | `10` | Maximum depth to search for connections. | +| Field | Type | Default | Description | +| :---------- | :---------------------------- | :------- | :--------------------------------------- | +| `source` | [`Locate`](../locate.md) | Required | The starting symbol for the path search. | +| `target` | [`Locate`](../locate.md) | Required | The ending symbol for the path search. | +| `max_depth` | `number` | `10` | Maximum search depth. | ## RelationResponse -| Field | Type | Description | -| :---------- | :------------------ | :--------------------------------------------------------------------------- | -| `source` | `HierarchyItem` | The resolved source symbol. | -| `target` | `HierarchyItem` | The resolved target symbol. | -| `chains` | `HierarchyItem[][]` | List of paths connecting source to target. Each path is a sequence of items. | -| `max_depth` | `number` | The maximum depth used for the search. | +| Field | Type | Description | +| :---------- | :---------------- | :---------------------------------------------------- | +| `request` | `RelationRequest` | The original request. | +| `source` | `ChainNode` | The resolved source symbol. | +| `target` | `ChainNode` | The resolved target symbol. | +| `chains` | `ChainNode[][]` | All paths found. Each path is a sequence of nodes. | +| `max_depth` | `number` | The maximum depth used for the search. | -## Implementation Guide +### ChainNode -This API is implemented by orchestrating standard LSP `Call Hierarchy` requests. +A lightweight symbol representation for path display: -### Algorithm: Bidirectional Search +| Field | Type | Description | +| :---------- | :------- | :--------------------------------- | +| `name` | `string` | Symbol name (e.g., `get_user`) | +| `kind` | `string` | Symbol kind (e.g., `Function`) | +| `file_path` | `Path` | File containing the symbol | +| `detail` | `string` | Optional: signature or extra info | -1. **Resolve Symbols**: - - Use `textDocument/definition` or `textDocument/documentSymbol` to resolve the `source` and `target` locations to valid LSP `CallHierarchyItem`s using `textDocument/prepareCallHierarchy`. -2. **Breadth-First Search (BFS)**: - - Perform `callHierarchy/outgoingCalls` from the `source` item. - - (Optional optimization) Simultaneously perform `callHierarchy/incomingCalls` from the `target` item. - - Maintain a `visited` set to detect and break recursive cycles. -3. **Path Reconstruction**: - - When the search frontiers meet (or one reaches the other end), reconstruction the full path. - - Filter out paths that exceed `max_depth`. +> **Design Note**: Unlike `HierarchyItem`, `ChainNode` has no `level` or `is_cycle` fields — the array index naturally represents position in the chain. -## Example Usage +## Example -### Scenario 1: How does `handle_request` reach `db.query`? +### How does `handle_request` reach `db.query`? #### Request @@ -47,21 +49,17 @@ This API is implemented by orchestrating standard LSP `Call Hierarchy` requests. { "source": { "file_path": "src/controllers.py", - "scope": { - "symbol_path": ["handle_request"] - } + "scope": { "symbol_path": ["handle_request"] } }, "target": { "file_path": "src/db.py", - "scope": { - "symbol_path": ["query"] - } + "scope": { "symbol_path": ["query"] } }, "max_depth": 5 } ``` -#### Markdown Rendered for LLM +#### Response (Markdown Rendered) ```markdown # Relation: `handle_request` → `query` @@ -69,19 +67,29 @@ This API is implemented by orchestrating standard LSP `Call Hierarchy` requests. Found 2 call chain(s): ### Chain 1 - 1. **handle_request** (`Function`) - `src/controllers.py` 2. **UserService.get_user** (`Method`) - `src/services/user.py` 3. **db.query** (`Function`) - `src/db.py` ### Chain 2 - 1. **handle_request** (`Function`) - `src/controllers.py` 2. **AuthService.validate_token** (`Method`) - `src/services/auth.py` 3. **SessionManager.get_session** (`Method`) - `src/services/session.py` 4. **db.query** (`Function`) - `src/db.py` ``` -## Pending Issues +## Implementation + +Orchestrates LSP `Call Hierarchy` requests: + +1. **Resolve**: Use `textDocument/prepareCallHierarchy` to get `CallHierarchyItem` for both endpoints +2. **Search**: BFS via `callHierarchy/outgoingCalls` from source (optionally bidirectional with `incomingCalls` from target) +3. **Reconstruct**: Build paths when frontiers meet; filter by `max_depth` + +## Design Decisions -- **TBD**: Search algorithm efficiency for large-scale dependency graphs and path filtering criteria. +| Decision | Rationale | +| :------- | :-------- | +| No pagination | `max_depth` bounds result size; path enumeration is typically small | +| `ChainNode` over `HierarchyItem` | Chains are linear; no tree semantics needed | +| Bidirectional search optional | Optimization for large graphs; not required for correctness | diff --git a/src/lsap/capability/__init__.py b/src/lsap/capability/__init__.py index 3b4a144..114ed0b 100644 --- a/src/lsap/capability/__init__.py +++ b/src/lsap/capability/__init__.py @@ -5,6 +5,7 @@ from .locate import LocateCapability from .outline import OutlineCapability from .reference import ReferenceCapability +from .relation import RelationCapability from .rename import RenameExecuteCapability, RenamePreviewCapability from .search import SearchCapability from .symbol import SymbolCapability @@ -16,6 +17,7 @@ class Capabilities(TypedDict): locate: LocateCapability outline: OutlineCapability references: ReferenceCapability + relation: RelationCapability rename_preview: RenamePreviewCapability rename_execute: RenameExecuteCapability search: SearchCapability diff --git a/src/lsap/capability/relation.py b/src/lsap/capability/relation.py new file mode 100644 index 0000000..77b2471 --- /dev/null +++ b/src/lsap/capability/relation.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from collections import deque +from functools import cached_property +from pathlib import Path +from typing import override + +from attrs import define +from lsp_client.capability.request import WithRequestCallHierarchy +from lsprotocol.types import ( + CallHierarchyItem, + CallHierarchyOutgoingCallsParams, +) + +from lsap.schema.locate import LocateRequest +from lsap.schema.relation import ChainNode, RelationRequest, RelationResponse +from lsap.utils.capability import ensure_capability + +from .abc import Capability +from .locate import LocateCapability + + +@define +class RelationCapability(Capability[RelationRequest, RelationResponse]): + @cached_property + def locate(self) -> LocateCapability: + return LocateCapability(self.client) + + @override + async def __call__(self, req: RelationRequest) -> RelationResponse | None: + # Resolve source symbol + source_req = LocateRequest(locate=req.source) + if not (source_loc := await self.locate(source_req)): + return None + + # Resolve target symbol + target_req = LocateRequest(locate=req.target) + if not (target_loc := await self.locate(target_req)): + return None + + # Get CallHierarchyItems for source and target + call_hierarchy = ensure_capability(self.client, WithRequestCallHierarchy) + + source_items = await call_hierarchy.prepare_call_hierarchy( + source_loc.file_path, source_loc.position.to_lsp() + ) + if not source_items: + return None + + target_items = await call_hierarchy.prepare_call_hierarchy( + target_loc.file_path, target_loc.position.to_lsp() + ) + if not target_items: + return None + + # For simplicity, use the first (primary) symbol if multiple exist + # This handles cases like overloaded functions or template instantiations + # TODO: Consider searching paths for all source-target combinations if needed + source_node = self._to_chain_node(source_items[0]) + target_node = self._to_chain_node(target_items[0]) + + # Find all paths from source to target + chains = await self._find_paths( + call_hierarchy, source_items[0], target_items[0], req.max_depth + ) + + return RelationResponse( + request=req, + source=source_node, + target=target_node, + chains=chains, + ) + + def _to_chain_node(self, item: CallHierarchyItem) -> ChainNode: + """Convert CallHierarchyItem to ChainNode""" + file_path = Path(self.client.from_uri(item.uri, relative=False)) + return ChainNode( + name=item.name, + kind=item.kind.name if hasattr(item.kind, "name") else str(item.kind), + file_path=file_path, + detail=item.detail, + ) + + async def _find_paths( + self, + call_hierarchy: WithRequestCallHierarchy, + source: CallHierarchyItem, + target: CallHierarchyItem, + max_depth: int, + ) -> list[list[ChainNode]]: + """ + Find all paths from source to target using BFS. + + Returns a list of chains, where each chain is a list of ChainNodes + representing a path from source to target. + """ + target_key = self._item_key(target) + found_chains: list[list[ChainNode]] = [] + + # BFS queue: (current_item, path_so_far, depth) + queue: deque[tuple[CallHierarchyItem, list[ChainNode], int]] = deque() + queue.append((source, [self._to_chain_node(source)], 0)) + + # Track visited nodes to avoid cycles + visited: set[str] = set() + + while queue: + current_item, path, depth = queue.popleft() + current_key = self._item_key(current_item) + + # Skip if we've exceeded max depth + if depth >= max_depth: + continue + + # Skip if already visited (cycle detection) + if current_key in visited: + continue + visited.add(current_key) + + # Check if we've reached the target + if current_key == target_key: + found_chains.append(path) + continue + + # Get outgoing calls from current item + outgoing_calls = ( + await call_hierarchy._request_call_hierarchy_outgoing_calls( + CallHierarchyOutgoingCallsParams(item=current_item) + ) + ) + + if not outgoing_calls: + continue + + # Add each outgoing call to the queue + for call in outgoing_calls: + next_item = call.to + next_node = self._to_chain_node(next_item) + next_path = path + [next_node] + queue.append((next_item, next_path, depth + 1)) + + return found_chains + + def _item_key(self, item: CallHierarchyItem) -> str: + """Generate a unique key for a CallHierarchyItem""" + file_path = self.client.from_uri(item.uri, relative=False) + return f"{file_path}:{item.range.start.line}:{item.range.start.character}:{item.name}" diff --git a/src/lsap/schema/__init__.py b/src/lsap/schema/__init__.py index 770ef3b..d4432ee 100644 --- a/src/lsap/schema/__init__.py +++ b/src/lsap/schema/__init__.py @@ -6,6 +6,7 @@ from .locate import LocateRequest, LocateResponse from .outline import OutlineRequest, OutlineResponse from .reference import ReferenceRequest, ReferenceResponse +from .relation import RelationRequest, RelationResponse from .rename import ( RenameExecuteRequest, RenameExecuteResponse, @@ -42,6 +43,10 @@ class Schema(NamedTuple): request=ReferenceRequest, response=ReferenceResponse, ), + "relation": Schema( + request=RelationRequest, + response=RelationResponse, + ), "rename_preview": Schema( request=RenamePreviewRequest, response=RenamePreviewResponse, diff --git a/src/lsap/schema/draft/relation.py b/src/lsap/schema/draft/relation.py deleted file mode 100644 index 07eeb43..0000000 --- a/src/lsap/schema/draft/relation.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Final - -from pydantic import ConfigDict - -from lsap.schema.abc import Request, Response -from lsap.schema.draft.hierarchy import HierarchyItem -from lsap.schema.locate import Locate - - -class RelationRequest(Request): - """ - Finds call chains connecting two symbols. - - Uses call hierarchy to trace paths from source to target. - """ - - source: Locate - target: Locate - - max_depth: int = 10 - """Maximum depth to search for connections""" - - -markdown_template: Final = """ -# Relation: `{{ source.name }}` → `{{ target.name }}` - -{% if chains.size > 0 %} -Found {{ chains | size }} call chain(s): - -{% for chain in chains %} -### Chain {{ forloop.index }} -{% for item in chain %} -{{ forloop.index }}. **{{ item.name }}** (`{{ item.kind }}`) - `{{ item.file_path }}` -{% endfor %} -{% endfor %} -{% else %} -No connection found between `{{ source.name }}` and `{{ target.name }}` (depth: {{ max_depth }}). -{% endif %} -""" - - -class RelationResponse(Response): - source: HierarchyItem - target: HierarchyItem - chains: list[list[HierarchyItem]] - """List of paths, where each path is a sequence of items from source to target.""" - - max_depth: int - - model_config = ConfigDict( - json_schema_extra={ - "markdown": markdown_template, - } - ) diff --git a/src/lsap/schema/relation.py b/src/lsap/schema/relation.py new file mode 100644 index 0000000..45f57c5 --- /dev/null +++ b/src/lsap/schema/relation.py @@ -0,0 +1,83 @@ +from pathlib import Path +from typing import Final + +from pydantic import BaseModel, ConfigDict + +from lsap.schema.abc import Request, Response +from lsap.schema.locate import Locate + + +class ChainNode(BaseModel): + """ + A node in a call chain. + """ + + name: str + kind: str + file_path: Path + detail: str | None = None + + +class RelationRequest(Request): + """ + Finds call chains connecting two symbols. + + Answers the question: "How does A reach B?" + + Use cases: + - Code flow understanding: How does handle_request reach db.query? + - Impact analysis: Which entry points are affected if I modify X? + - Architecture validation: Verify module A never calls module B + """ + + source: Locate + target: Locate + + max_depth: int = 10 + """Maximum depth to search for connections""" + + +markdown_template: Final = """ +# Relation: `{{ source.name }}` → `{{ target.name }}` + +{% if chains.size > 0 %} +Found {{ chains.size }} call chain(s): + +{% for chain in chains %} +### Chain {{ forloop.index }} +{% for node in chain %} +{{ forloop.index }}. **{{ node.name }}** (`{{ node.kind }}`) - `{{ node.file_path }}` +{%- if node.detail %} — {{ node.detail }}{% endif %} +{% endfor %} +{% endfor %} +{% else %} +No connection found between `{{ source.name }}` and `{{ target.name }}` within depth {{ request.max_depth }}. +{% endif %} +""" + + +class RelationResponse(Response): + request: RelationRequest + """The original request""" + + source: ChainNode + """The resolved source symbol""" + + target: ChainNode + """The resolved target symbol""" + + chains: list[list[ChainNode]] + """All paths found. Each path is a sequence of nodes from source to target.""" + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "ChainNode", + "RelationRequest", + "RelationResponse", +] diff --git a/tests/test_relation.py b/tests/test_relation.py new file mode 100644 index 0000000..f5b6668 --- /dev/null +++ b/tests/test_relation.py @@ -0,0 +1,399 @@ +""" +Functional tests for Relation API. + +Tests the call chain discovery capability that answers "how does A reach B?" +""" + +from pathlib import Path + +from lsap.schema.locate import Locate, SymbolScope +from lsap.schema.relation import ChainNode, RelationRequest, RelationResponse + + +def test_chain_node_creation(): + """Test creating a ChainNode.""" + node = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + assert node.name == "handle_request" + assert node.kind == "Function" + assert node.file_path == Path("src/controllers.py") + assert node.detail is None + + +def test_chain_node_with_detail(): + """Test creating a ChainNode with detail.""" + node = ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + detail="(user_id: int) -> User", + ) + assert node.name == "UserService.get_user" + assert node.kind == "Method" + assert node.detail == "(user_id: int) -> User" + + +def test_relation_request_creation(): + """Test creating a RelationRequest.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + assert req.source.file_path == Path("src/controllers.py") + assert req.target.file_path == Path("src/db.py") + assert req.max_depth == 10 # default + + +def test_relation_request_with_custom_max_depth(): + """Test creating a RelationRequest with custom max_depth.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + max_depth=5, + ) + assert req.max_depth == 5 + + +def test_relation_response_with_single_chain(): + """Test RelationResponse with a single call chain.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + chain = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 + assert resp.chains[0][0].name == "handle_request" + assert resp.chains[0][1].name == "UserService.get_user" + assert resp.chains[0][2].name == "query" + + +def test_relation_response_with_multiple_chains(): + """Test RelationResponse with multiple call chains (example from docs).""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + # Chain 1: handle_request -> UserService.get_user -> db.query + chain1 = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + ), + target, + ] + + # Chain 2: handle_request -> AuthService.validate_token -> SessionManager.get_session -> db.query + chain2 = [ + source, + ChainNode( + name="AuthService.validate_token", + kind="Method", + file_path=Path("src/services/auth.py"), + ), + ChainNode( + name="SessionManager.get_session", + kind="Method", + file_path=Path("src/services/session.py"), + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + max_depth=5, + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain1, chain2], + ) + + assert len(resp.chains) == 2 + assert len(resp.chains[0]) == 3 + assert len(resp.chains[1]) == 4 + assert resp.chains[0][1].name == "UserService.get_user" + assert resp.chains[1][1].name == "AuthService.validate_token" + assert resp.chains[1][2].name == "SessionManager.get_session" + + +def test_relation_response_with_no_chains(): + """Test RelationResponse when no connection found.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="unrelated_function", + kind="Function", + file_path=Path("src/utils.py"), + ) + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/utils.py"), + scope=SymbolScope(symbol_path=["unrelated_function"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[], + ) + + assert len(resp.chains) == 0 + + +def test_relation_response_markdown_format_with_chains(): + """Test markdown formatting of RelationResponse with chains.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + chain = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + detail="(user_id: int) -> User", + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + markdown = resp.format() + + # Check that markdown contains key elements + assert "handle_request" in markdown + assert "query" in markdown + assert "Found 1 call chain(s)" in markdown + assert "Chain 1" in markdown + assert "UserService.get_user" in markdown + assert "Method" in markdown + assert "(user_id: int) -> User" in markdown + + +def test_relation_response_markdown_format_no_chains(): + """Test markdown formatting of RelationResponse with no chains.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="unrelated_function", + kind="Function", + file_path=Path("src/utils.py"), + ) + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/utils.py"), + scope=SymbolScope(symbol_path=["unrelated_function"]), + ), + max_depth=5, + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[], + ) + + markdown = resp.format() + + # Check that markdown indicates no connection + assert "handle_request" in markdown + assert "unrelated_function" in markdown + assert "No connection found" in markdown + assert "within depth 5" in markdown + + +def test_chain_node_equality(): + """Test that ChainNodes with same data are equal.""" + node1 = ChainNode( + name="foo", + kind="Function", + file_path=Path("test.py"), + detail="detail", + ) + node2 = ChainNode( + name="foo", + kind="Function", + file_path=Path("test.py"), + detail="detail", + ) + assert node1 == node2 + + +def test_relation_response_chain_order(): + """Test that chains preserve order.""" + source = ChainNode( + name="a", + kind="Function", + file_path=Path("a.py"), + ) + middle1 = ChainNode( + name="b", + kind="Function", + file_path=Path("b.py"), + ) + middle2 = ChainNode( + name="c", + kind="Function", + file_path=Path("c.py"), + ) + target = ChainNode( + name="d", + kind="Function", + file_path=Path("d.py"), + ) + + chain = [source, middle1, middle2, target] + + req = RelationRequest( + source=Locate( + file_path=Path("a.py"), + scope=SymbolScope(symbol_path=["a"]), + ), + target=Locate( + file_path=Path("d.py"), + scope=SymbolScope(symbol_path=["d"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + # Verify order is preserved + assert resp.chains[0][0].name == "a" + assert resp.chains[0][1].name == "b" + assert resp.chains[0][2].name == "c" + assert resp.chains[0][3].name == "d" + + +def test_relation_request_with_nested_symbol_path(): + """Test RelationRequest with nested symbol paths.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/services/user.py"), + scope=SymbolScope(symbol_path=["UserService", "get_user"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["Database", "query"]), + ), + max_depth=3, + ) + + assert req.source.scope.symbol_path == ["UserService", "get_user"] + assert req.target.scope.symbol_path == ["Database", "query"] + assert req.max_depth == 3 From 7310cef7d03e11df36941d27e6a9e00c07c41924 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Sat, 10 Jan 2026 17:36:53 +0800 Subject: [PATCH 02/10] fix: correct BFS path-finding in Relation API and add integration tests (#17) * Initial plan * fix: address PR review comments for Relation API - Fix documentation: Remove incorrect max_depth field from RelationResponse - Fix BFS bug: Move target check before visited.add() to find all paths - Fix BFS optimization: Check visited before adding to queue Co-authored-by: observerw <20661574+observerw@users.noreply.github.com> * test: add comprehensive integration tests for RelationCapability - Add mock LSP client with call hierarchy and document symbol support - Test single path discovery - Test multiple paths between symbols - Test scenario with no path - Test max_depth boundary conditions - Test cycle detection with recursive calls - Test direct calls - Test paths of different lengths Co-authored-by: observerw <20661574+observerw@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: observerw <20661574+observerw@users.noreply.github.com> --- docs/schemas/draft/relation.md | 3 +- src/lsap/capability/relation.py | 19 +- tests/test_relation.py | 421 +++++++++++++++++++++++++++++++- 3 files changed, 433 insertions(+), 10 deletions(-) diff --git a/docs/schemas/draft/relation.md b/docs/schemas/draft/relation.md index cb25769..6b49e96 100644 --- a/docs/schemas/draft/relation.md +++ b/docs/schemas/draft/relation.md @@ -24,7 +24,8 @@ This is a high-value query for: | `source` | `ChainNode` | The resolved source symbol. | | `target` | `ChainNode` | The resolved target symbol. | | `chains` | `ChainNode[][]` | All paths found. Each path is a sequence of nodes. | -| `max_depth` | `number` | The maximum depth used for the search. | + +The maximum depth used for the search is available as `request.max_depth`, since the response includes the original `RelationRequest`. ### ChainNode diff --git a/src/lsap/capability/relation.py b/src/lsap/capability/relation.py index 77b2471..14ec234 100644 --- a/src/lsap/capability/relation.py +++ b/src/lsap/capability/relation.py @@ -108,6 +108,11 @@ async def _find_paths( current_item, path, depth = queue.popleft() current_key = self._item_key(current_item) + # Check if we've reached the target + if current_key == target_key: + found_chains.append(path) + continue + # Skip if we've exceeded max depth if depth >= max_depth: continue @@ -117,11 +122,6 @@ async def _find_paths( continue visited.add(current_key) - # Check if we've reached the target - if current_key == target_key: - found_chains.append(path) - continue - # Get outgoing calls from current item outgoing_calls = ( await call_hierarchy._request_call_hierarchy_outgoing_calls( @@ -135,9 +135,12 @@ async def _find_paths( # Add each outgoing call to the queue for call in outgoing_calls: next_item = call.to - next_node = self._to_chain_node(next_item) - next_path = path + [next_node] - queue.append((next_item, next_path, depth + 1)) + next_key = self._item_key(next_item) + # Skip if already visited to prevent redundant queue entries + if next_key not in visited: + next_node = self._to_chain_node(next_item) + next_path = path + [next_node] + queue.append((next_item, next_path, depth + 1)) return found_chains diff --git a/tests/test_relation.py b/tests/test_relation.py index f5b6668..55bb980 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -5,7 +5,35 @@ """ from pathlib import Path - +from contextlib import asynccontextmanager + +import pytest +from lsprotocol.types import ( + CallHierarchyIncomingCall, + CallHierarchyItem, + CallHierarchyOutgoingCall, + CallHierarchyOutgoingCallsParams, + DocumentSymbol, + SymbolKind, +) +from lsprotocol.types import Position as LSPPosition +from lsprotocol.types import Range as LSPRange +from lsp_client.capability.request import ( + WithRequestCallHierarchy, + WithRequestDocumentSymbol, +) +from lsp_client.client.document_state import DocumentStateManager +from lsp_client.protocol import CapabilityClientProtocol +from lsp_client.protocol.lang import LanguageConfig +from lsp_client.utils.config import ConfigurationMap +from lsp_client.utils.workspace import ( + DEFAULT_WORKSPACE_DIR, + Workspace, + WorkspaceFolder, +) +from lsprotocol.types import LanguageKind + +from lsap.capability.relation import RelationCapability from lsap.schema.locate import Locate, SymbolScope from lsap.schema.relation import ChainNode, RelationRequest, RelationResponse @@ -397,3 +425,394 @@ def test_relation_request_with_nested_symbol_path(): assert req.source.scope.symbol_path == ["UserService", "get_user"] assert req.target.scope.symbol_path == ["Database", "query"] assert req.max_depth == 3 + + +# ============================================================================ +# Integration tests with mock LSP client +# ============================================================================ + + +class MockRelationClient( + WithRequestCallHierarchy, WithRequestDocumentSymbol, CapabilityClientProtocol +): + """Mock client for testing RelationCapability with call hierarchy support.""" + + def __init__(self, call_graph: dict[str, list[str]] | None = None): + """ + Initialize mock client with a call graph. + + call_graph: Dictionary mapping symbol names to list of symbols they call. + Example: {"A": ["B", "C"], "B": ["D"]} means A calls B and C, B calls D. + """ + self.call_graph = call_graph or {} + self._workspace = Workspace( + { + DEFAULT_WORKSPACE_DIR: WorkspaceFolder( + uri=Path.cwd().as_uri(), + name=DEFAULT_WORKSPACE_DIR, + ) + } + ) + self._config_map = ConfigurationMap() + self._doc_state = DocumentStateManager() + + def from_uri(self, uri: str, *, relative: bool = True) -> Path: + return Path(uri.replace("file://", "")) + + def get_workspace(self) -> Workspace: + return self._workspace + + def get_config_map(self) -> ConfigurationMap: + return self._config_map + + def get_document_state(self) -> DocumentStateManager: + return self._doc_state + + @classmethod + def get_language_config(cls): + return LanguageConfig( + kind=LanguageKind.Python, + suffixes=["py"], + project_files=["pyproject.toml"], + ) + + async def request(self, req, schema): + return None + + async def notify(self, msg): + pass + + async def read_file(self, file_path) -> str: + return "# Mock file content" + + async def write_file(self, uri: str, content: str) -> None: + pass + + @asynccontextmanager + async def open_files(self, *file_paths): + yield + + async def request_document_symbol_list( + self, file_path: Path + ) -> list[DocumentSymbol]: + """Mock document symbol list - returns a single function symbol based on file name.""" + # Extract symbol name from file path (e.g., test_A.py -> A) + name = file_path.stem.replace("test_", "") + if name in self.call_graph or any( + name in calls for calls in self.call_graph.values() + ): + return [ + DocumentSymbol( + name=name, + kind=SymbolKind.Function, + range=LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=1, character=0), + ), + selection_range=LSPRange( + start=LSPPosition(line=0, character=4), + end=LSPPosition(line=0, character=4 + len(name)), + ), + ) + ] + return [] + + async def request_document_symbol_information_list(self, file_path): + return [] + + def _make_call_hierarchy_item(self, name: str) -> CallHierarchyItem: + """Create a mock CallHierarchyItem for a symbol name.""" + return CallHierarchyItem( + name=name, + kind=SymbolKind.Function, + uri=f"file://test_{name}.py", + range=LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=1, character=0), + ), + selection_range=LSPRange( + start=LSPPosition(line=0, character=4), + end=LSPPosition(line=0, character=4 + len(name)), + ), + ) + + async def prepare_call_hierarchy( + self, file_path: Path, position: LSPPosition + ) -> list[CallHierarchyItem] | None: + """Mock prepare_call_hierarchy - returns item based on file path.""" + # Extract symbol name from file path (e.g., test_A.py -> A) + name = file_path.stem.replace("test_", "") + if name in self.call_graph or any( + name in calls for calls in self.call_graph.values() + ): + return [self._make_call_hierarchy_item(name)] + return None + + async def _request_call_hierarchy_outgoing_calls( + self, params: CallHierarchyOutgoingCallsParams + ) -> list[CallHierarchyOutgoingCall] | None: + """Mock outgoing calls based on the call graph.""" + item_name = params.item.name + if item_name not in self.call_graph: + return [] + + outgoing = [] + for target_name in self.call_graph[item_name]: + target_item = self._make_call_hierarchy_item(target_name) + outgoing.append( + CallHierarchyOutgoingCall( + to=target_item, + from_ranges=[ + LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=0, character=1), + ) + ], + ) + ) + return outgoing + + +@pytest.mark.asyncio +async def test_relation_capability_single_path(): + """Test finding a single path between two symbols.""" + # Call graph: A -> B -> C + call_graph = {"A": ["B"], "B": ["C"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "C" + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # A -> B -> C + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + assert resp.chains[0][2].name == "C" + + +@pytest.mark.asyncio +async def test_relation_capability_multiple_paths(): + """Test finding multiple paths between two symbols.""" + # Call graph: A -> B -> D, A -> C -> D (two paths from A to D) + call_graph = {"A": ["B", "C"], "B": ["D"], "C": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 2 # Two paths: A->B->D and A->C->D + + # Both chains should start with A and end with D + for chain in resp.chains: + assert chain[0].name == "A" + assert chain[-1].name == "D" + assert len(chain) == 3 # A -> (B or C) -> D + + # Check that we have both paths + middle_nodes = {chain[1].name for chain in resp.chains} + assert middle_nodes == {"B", "C"} + + +@pytest.mark.asyncio +async def test_relation_capability_no_path(): + """Test when no path exists between source and target.""" + # Call graph: A -> B, C -> D (no connection from A to D) + call_graph = {"A": ["B"], "C": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 0 + + +@pytest.mark.asyncio +async def test_relation_capability_max_depth(): + """Test that max_depth is respected.""" + # Call graph: A -> B -> C -> D -> E (chain of 4 calls) + call_graph = {"A": ["B"], "B": ["C"], "C": ["D"], "D": ["E"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + # With max_depth=3, we can reach D (3 hops: A->B, B->C, C->D) + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=3, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # A -> B -> C + + # With max_depth=2, we can reach C (2 hops: A->B, B->C) + req2 = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=2, + ) + + resp2 = await capability(req2) + assert resp2 is not None + assert len(resp2.chains) == 1 + + # With max_depth=1, we cannot reach C (only B is reachable) + req3 = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=1, + ) + + resp3 = await capability(req3) + assert resp3 is not None + assert len(resp3.chains) == 0 # Cannot reach C within max_depth=1 + + +@pytest.mark.asyncio +async def test_relation_capability_cycle_detection(): + """Test that cycles are properly detected and don't cause infinite loops.""" + # Call graph with cycle: A -> B -> C -> B (cycle between B and C) + # But there's also A -> B -> D (path without cycle) + call_graph = {"A": ["B"], "B": ["C", "D"], "C": ["B"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=10, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 1 # Should find A -> B -> D + assert len(resp.chains[0]) == 3 + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + assert resp.chains[0][2].name == "D" + + +@pytest.mark.asyncio +async def test_relation_capability_direct_call(): + """Test direct call (source directly calls target).""" + # Call graph: A -> B (direct call) + call_graph = {"A": ["B"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_B.py"), + scope=SymbolScope(symbol_path=["B"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 2 # Just A -> B + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + + +@pytest.mark.asyncio +async def test_relation_capability_different_path_lengths(): + """Test finding paths of different lengths.""" + # Call graph: A -> D (direct), A -> B -> D (2 hops), A -> C -> E -> D (3 hops) + call_graph = {"A": ["B", "C", "D"], "B": ["D"], "C": ["E"], "E": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 3 # Three paths of different lengths + + # Check path lengths + path_lengths = sorted([len(chain) for chain in resp.chains]) + assert path_lengths == [2, 3, 4] # Direct, 2-hop, and 3-hop paths From 5bfe77324fd0c47ccbe98c5d1b75059028ff6a69 Mon Sep 17 00:00:00 2001 From: observerw Date: Sat, 24 Jan 2026 17:51:11 +0800 Subject: [PATCH 03/10] update --- src/lsap/schema/relation.py | 2 +- tests/test_relation.py | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/lsap/schema/relation.py b/src/lsap/schema/relation.py index 45f57c5..d9c2bf3 100644 --- a/src/lsap/schema/relation.py +++ b/src/lsap/schema/relation.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict -from lsap.schema.abc import Request, Response +from lsap.schema._abc import Request, Response from lsap.schema.locate import Locate diff --git a/tests/test_relation.py b/tests/test_relation.py index 55bb980..37b6bb2 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -4,20 +4,10 @@ Tests the call chain discovery capability that answers "how does A reach B?" """ -from pathlib import Path from contextlib import asynccontextmanager +from pathlib import Path import pytest -from lsprotocol.types import ( - CallHierarchyIncomingCall, - CallHierarchyItem, - CallHierarchyOutgoingCall, - CallHierarchyOutgoingCallsParams, - DocumentSymbol, - SymbolKind, -) -from lsprotocol.types import Position as LSPPosition -from lsprotocol.types import Range as LSPRange from lsp_client.capability.request import ( WithRequestCallHierarchy, WithRequestDocumentSymbol, @@ -31,7 +21,16 @@ Workspace, WorkspaceFolder, ) -from lsprotocol.types import LanguageKind +from lsprotocol.types import ( + CallHierarchyItem, + CallHierarchyOutgoingCall, + CallHierarchyOutgoingCallsParams, + DocumentSymbol, + LanguageKind, + SymbolKind, +) +from lsprotocol.types import Position as LSPPosition +from lsprotocol.types import Range as LSPRange from lsap.capability.relation import RelationCapability from lsap.schema.locate import Locate, SymbolScope From 662123499ad07a0b72face6fcf7131ac591ea691 Mon Sep 17 00:00:00 2001 From: observerw Date: Sat, 24 Jan 2026 21:06:57 +0800 Subject: [PATCH 04/10] fix(relation): use list unpacking for path concatenation --- src/lsap/capability/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lsap/capability/relation.py b/src/lsap/capability/relation.py index 14ec234..ef54f70 100644 --- a/src/lsap/capability/relation.py +++ b/src/lsap/capability/relation.py @@ -139,7 +139,7 @@ async def _find_paths( # Skip if already visited to prevent redundant queue entries if next_key not in visited: next_node = self._to_chain_node(next_item) - next_path = path + [next_node] + next_path = [*path, next_node] queue.append((next_item, next_path, depth + 1)) return found_chains From 117eac1d59dd3a59b3eb9f5152399b69e2d54cfd Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 02:19:35 +0800 Subject: [PATCH 05/10] feat(schema): add Inspect API schema for usage-oriented symbol inspection --- src/lsap/schema/draft/inspect.py | 166 +++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/lsap/schema/draft/inspect.py diff --git a/src/lsap/schema/draft/inspect.py b/src/lsap/schema/draft/inspect.py new file mode 100644 index 0000000..b379a33 --- /dev/null +++ b/src/lsap/schema/draft/inspect.py @@ -0,0 +1,166 @@ +""" +# Inspect API + +The Inspect API provides "how to use" information for a symbol, including usage examples, +signatures, and documentation. It is designed to help Agents understand how to correctly +invoke or interact with a symbol. + +## Example Usage + +### Scenario 1: Inspecting a function for usage examples + +Request: + +```json +{ + "locate": { + "file_path": "src/utils.py", + "scope": { + "symbol_path": ["format_date"] + } + }, + "include_examples": 5, + "include_signature": true +} +``` + +### Scenario 2: Inspecting a class with call hierarchy + +Request: + +```json +{ + "locate": { + "file_path": "src/models.py", + "scope": { + "symbol_path": ["User"] + } + }, + "include_call_hierarchy": true +} +``` +""" + +from typing import Final + +from pydantic import BaseModel, ConfigDict, Field + +from .._abc import Response +from ..locate import LocateRequest +from ..models import CallHierarchy, Location, SymbolDetailInfo, SymbolInfo + + +class UsageExample(BaseModel): + """A code snippet showing how a symbol is used in context.""" + + code: str = Field(..., description="Code snippet with context") + context: SymbolInfo | None = Field(None, description="Where this usage occurs") + location: Location = Field(..., description="Exact position of the usage") + call_pattern: str | None = Field( + None, description="Extracted pattern like 'func(arg1, arg2)'" + ) + + +class InspectRequest(LocateRequest): + """ + Request to inspect a symbol for usage-oriented information. + + Provides signatures, documentation, and real-world usage examples from the codebase. + """ + + include_examples: int = Field(default=3, ge=0, le=20) + """Number of usage examples to include.""" + + include_signature: bool = True + """Whether to include the symbol's signature.""" + + include_doc: bool = True + """Whether to include the symbol's documentation (hover).""" + + include_call_hierarchy: bool = False + """Whether to include call hierarchy information.""" + + include_external: bool = False + """Whether to include examples from external libraries if available.""" + + context_lines: int = Field(default=2, ge=0, le=10) + """Number of context lines to include around each example.""" + + +markdown_template: Final = """ +# Inspect: `{{ info.path | join: "." }}` (`{{ info.kind }}`) + +{% if signature != nil -%} +## Signature +```python +{{ signature }} +``` +{%- endif %} + +{% if info.hover != nil -%} +## Documentation +{{ info.hover }} +{%- endif %} + +{% if examples.size > 0 -%} +## Usage Examples +{% for example in examples -%} +### Example {{ forloop.index }} +{% if example.context != nil -%} +In `{{ example.context.path | join: "." }}` (`{{ example.context.kind }}`) at `{{ example.location.file_path }}:{{ example.location.range.start.line }}` +{%- else -%} +At `{{ example.location.file_path }}:{{ example.location.range.start.line }}` +{%- endif %} + +{% if example.call_pattern != nil -%} +Pattern: `{{ example.call_pattern }}` +{%- endif %} + +```{{ example.location.file_path.suffix | remove_first: "." }} +{{ example.code }} +``` +{% endfor -%} +{%- endif %} + +{% if call_hierarchy != nil -%} +{% if call_hierarchy.incoming.size > 0 -%} +## Incoming Calls +{% for item in call_hierarchy.incoming -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line }}` +{% endfor -%} +{%- endif %} + +{% if call_hierarchy.outgoing.size > 0 -%} +## Outgoing Calls +{% for item in call_hierarchy.outgoing -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +--- +> [!TIP] +> Use these examples to understand the expected arguments and common calling patterns for this symbol. +""" + + +class InspectResponse(Response): + """Response containing usage-oriented information about a symbol.""" + + info: SymbolDetailInfo + signature: str | None = None + examples: list[UsageExample] = Field(default_factory=list) + call_hierarchy: CallHierarchy | None = None + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "UsageExample", + "InspectRequest", + "InspectResponse", +] From 6819b8b2fe64e68eb1b294a04683742481eca311 Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 02:24:28 +0800 Subject: [PATCH 06/10] test(schema): add tests for Inspect API schema --- tests/test_inspect_schema.py | 192 +++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 tests/test_inspect_schema.py diff --git a/tests/test_inspect_schema.py b/tests/test_inspect_schema.py new file mode 100644 index 0000000..0288f0b --- /dev/null +++ b/tests/test_inspect_schema.py @@ -0,0 +1,192 @@ +from pathlib import Path +import pytest +from pydantic import ValidationError +from lsap.schema.draft.inspect import InspectRequest, InspectResponse, UsageExample +from lsap.schema.locate import Locate +from lsap.schema.models import ( + SymbolDetailInfo, + Location, + Position, + Range, + SymbolKind, + SymbolInfo, + CallHierarchy, + CallHierarchyItem, +) + + +def test_inspect_request_defaults(): + """Test InspectRequest default values.""" + req = InspectRequest(locate=Locate(file_path=Path("test.py"), find="func")) + assert req.include_examples == 3 + assert req.include_signature is True + assert req.include_doc is True + assert req.include_call_hierarchy is False + assert req.include_external is False + assert req.context_lines == 2 + + +def test_inspect_request_validation(): + """Test InspectRequest field validation.""" + # Test include_examples range + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), include_examples=-1 + ) + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), include_examples=21 + ) + + # Test context_lines range + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), context_lines=-1 + ) + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), context_lines=11 + ) + + +def test_usage_example_model(): + """Test UsageExample model instantiation.""" + example = UsageExample( + code="func('test')", + location=Location( + file_path=Path("main.py"), + range=Range( + start=Position(line=10, character=5), + end=Position(line=10, character=15), + ), + ), + context=SymbolInfo( + name="caller", + kind=SymbolKind.Function, + file_path=Path("main.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["caller"], + ), + call_pattern="func(arg)", + ) + assert example.code == "func('test')" + assert example.call_pattern == "func(arg)" + assert example.context.name == "caller" + + +def test_inspect_response_serialization(): + """Test InspectResponse serialization and deserialization.""" + info = SymbolDetailInfo( + name="my_func", + kind=SymbolKind.Function, + file_path=Path("test.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["my_func"], + hover="My function documentation", + ) + + example = UsageExample( + code="my_func()", + location=Location( + file_path=Path("app.py"), + range=Range( + start=Position(line=2, character=1), end=Position(line=2, character=9) + ), + ), + ) + + resp = InspectResponse( + info=info, signature="def my_func() -> None", examples=[example] + ) + + # Serialize to dict + data = resp.model_dump() + assert data["info"]["name"] == "my_func" + assert len(data["examples"]) == 1 + assert data["signature"] == "def my_func() -> None" + + # Deserialize back + resp2 = InspectResponse.model_validate(data) + assert resp2.info.name == "my_func" + assert resp2.signature == "def my_func() -> None" + assert len(resp2.examples) == 1 + + +def test_inspect_response_markdown_rendering(): + """Test InspectResponse.format('markdown') renders correctly.""" + info = SymbolDetailInfo( + name="my_func", + kind=SymbolKind.Function, + file_path=Path("test.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["my_func"], + hover="My function documentation", + ) + + example = UsageExample( + code="my_func()", + location=Location( + file_path=Path("app.py"), + range=Range( + start=Position(line=2, character=1), end=Position(line=2, character=9) + ), + ), + context=SymbolInfo( + name="main", + kind=SymbolKind.Function, + file_path=Path("app.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["main"], + ), + call_pattern="my_func()", + ) + + call_hierarchy = CallHierarchy( + incoming=[ + CallHierarchyItem( + name="caller_func", + kind=SymbolKind.Function, + file_path=Path("caller.py"), + range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + selection_range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + ) + ], + outgoing=[], + ) + + resp = InspectResponse( + info=info, + signature="def my_func() -> None", + examples=[example], + call_hierarchy=call_hierarchy, + ) + + markdown = resp.format("markdown") + + assert "# Inspect: `my_func` (`function`)" in markdown + assert "## Signature" in markdown + assert "def my_func() -> None" in markdown + assert "## Documentation" in markdown + assert "My function documentation" in markdown + assert "## Usage Examples" in markdown + assert "### Example 1" in markdown + assert "In `main` (`function`)" in markdown + assert "Pattern: `my_func()`" in markdown + assert "my_func()" in markdown + assert "## Incoming Calls" in markdown + assert "- `caller_func` (`function`)" in markdown + assert "Use these examples to understand" in markdown From e2ff5362688a3ffe7768e8b702dd31962332bf4c Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 02:27:16 +0800 Subject: [PATCH 07/10] feat(schema): add Explore API schema for relationship-oriented code exploration --- src/lsap/schema/draft/explore.py | 182 +++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 src/lsap/schema/draft/explore.py diff --git a/src/lsap/schema/draft/explore.py b/src/lsap/schema/draft/explore.py new file mode 100644 index 0000000..33f7678 --- /dev/null +++ b/src/lsap/schema/draft/explore.py @@ -0,0 +1,182 @@ +""" +# Explore API + +The Explore API provides relationship-oriented code exploration, answering "what's around this symbol" +(siblings, dependencies, hierarchy) rather than just "what it is" (definition). + +## Example Usage + +### Scenario 1: Exploring siblings and dependencies of a class + +Request: + +```json +{ + "locate": { + "file_path": "src/models.py", + "scope": { + "symbol_path": ["User"] + } + }, + "include": ["siblings", "dependencies"], + "max_items": 10 +} +``` + +### Scenario 2: Exploring class hierarchy and calls + +Request: + +```json +{ + "locate": { + "file_path": "src/services.py", + "scope": { + "symbol_path": ["AuthService"] + } + }, + "include": ["hierarchy", "calls"], + "resolve_info": true +} +``` +""" + +from typing import Final, Literal + +from pydantic import BaseModel, ConfigDict, Field + +from .._abc import Response +from ..locate import LocateRequest +from ..models import CallHierarchy, SymbolInfo + + +class HierarchyInfo(BaseModel): + """Information about the inheritance hierarchy of a symbol.""" + + parents: list[SymbolInfo] = Field( + default_factory=list, description="Parent classes or interfaces" + ) + children: list[SymbolInfo] = Field( + default_factory=list, description="Child classes or implementations" + ) + + +class ExploreRequest(LocateRequest): + """ + Request to explore relationships around a symbol. + + Provides information about siblings, dependencies, hierarchy, and calls. + """ + + include: list[ + Literal["siblings", "dependencies", "dependents", "hierarchy", "calls"] + ] = Field(default=["siblings", "dependencies"]) + """Types of relationships to include in the exploration.""" + + max_items: int = Field(default=10, ge=1, le=50) + """Maximum number of items to return for each relationship type.""" + + resolve_info: bool = False + """Whether to resolve detailed information for symbols.""" + + include_external: bool = False + """Whether to include external dependencies if available.""" + + +markdown_template: Final = """ +# Explore: `{{ target.path | join: "." }}` (`{{ target.kind }}`) + +{% if siblings.size > 0 -%} +## Siblings +{% for item in siblings -%} +- `{{ item.name }}` (`{{ item.kind }}`) {% if item.range != nil %}at line {{ item.range.start.line | plus: 1 }}{% endif %} +{% endfor -%} +{%- endif %} + +{% if dependencies.size > 0 -%} +## Dependencies +{% for item in dependencies -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if dependents.size > 0 -%} +## Dependents +{% for item in dependents -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if hierarchy != nil -%} +## Hierarchy +{% if hierarchy.parents.size > 0 -%} +### Parents +{% for item in hierarchy.parents -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if hierarchy.children.size > 0 -%} +### Children +{% for item in hierarchy.children -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +{% if calls != nil -%} +## Call Hierarchy +{% if calls.incoming.size > 0 -%} +### Incoming Calls +{% for item in calls.incoming -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line | plus: 1 }}` +{% endfor -%} +{%- endif %} + +{% if calls.outgoing.size > 0 -%} +### Outgoing Calls +{% for item in calls.outgoing -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line | plus: 1 }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +--- +> [!TIP] +> Use this map to understand the architectural context and impact of changes to this symbol. +""" + + +class ExploreResponse(Response): + """Response containing relationship-oriented information about a symbol.""" + + target: SymbolInfo + """The symbol being explored.""" + + siblings: list[SymbolInfo] = Field(default_factory=list) + """Symbols defined in the same scope or file.""" + + dependencies: list[SymbolInfo] = Field(default_factory=list) + """Symbols that this symbol depends on.""" + + dependents: list[SymbolInfo] = Field(default_factory=list) + """Symbols that depend on this symbol.""" + + hierarchy: HierarchyInfo | None = None + """Inheritance hierarchy information.""" + + calls: CallHierarchy | None = None + """Call hierarchy information.""" + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "HierarchyInfo", + "ExploreRequest", + "ExploreResponse", +] From a2f0403cdc20a1e13b08659efb612704b8acd2df Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 02:30:13 +0800 Subject: [PATCH 08/10] test(schema): add tests for Explore API schema --- tests/test_explore_schema.py | 239 +++++++++++++++++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 tests/test_explore_schema.py diff --git a/tests/test_explore_schema.py b/tests/test_explore_schema.py new file mode 100644 index 0000000..93b52dd --- /dev/null +++ b/tests/test_explore_schema.py @@ -0,0 +1,239 @@ +from pathlib import Path +import pytest +from pydantic import ValidationError +from lsap.schema.draft.explore import ExploreRequest, ExploreResponse, HierarchyInfo +from lsap.schema.locate import Locate +from lsap.schema.models import ( + SymbolInfo, + CallHierarchy, + CallHierarchyItem, + SymbolKind, + Position, + Range, +) + + +def test_explore_request_defaults(): + """Test ExploreRequest default values.""" + req = ExploreRequest(locate=Locate(file_path=Path("test.py"), find="MyClass")) + assert req.include == ["siblings", "dependencies"] + assert req.max_items == 10 + assert req.resolve_info is False + assert req.include_external is False + + +def test_explore_request_validation(): + """Test ExploreRequest field validation.""" + # Test max_items range (1-50) + with pytest.raises(ValidationError): + ExploreRequest( + locate=Locate(file_path=Path("test.py"), find="MyClass"), max_items=0 + ) + with pytest.raises(ValidationError): + ExploreRequest( + locate=Locate(file_path=Path("test.py"), find="MyClass"), max_items=51 + ) + + +def test_hierarchy_info_model(): + """Test HierarchyInfo model with parents and children.""" + parent = SymbolInfo( + name="Base", + kind=SymbolKind.Class, + file_path=Path("base.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["Base"], + ) + child = SymbolInfo( + name="Sub", + kind=SymbolKind.Class, + file_path=Path("sub.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["Sub"], + ) + info = HierarchyInfo(parents=[parent], children=[child]) + assert len(info.parents) == 1 + assert len(info.children) == 1 + assert info.parents[0].name == "Base" + assert info.children[0].name == "Sub" + + +def test_explore_response_serialization(): + """Test ExploreResponse serialization and deserialization.""" + target = SymbolInfo( + name="MyClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["MyClass"], + ) + + sibling = SymbolInfo( + name="OtherClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=20, character=1), end=Position(line=30, character=1) + ), + path=["OtherClass"], + ) + + resp = ExploreResponse( + target=target, + siblings=[sibling], + dependencies=[], + dependents=[], + hierarchy=None, + calls=None, + ) + + # Serialize to dict + data = resp.model_dump() + assert data["target"]["name"] == "MyClass" + assert len(data["siblings"]) == 1 + assert data["siblings"][0]["name"] == "OtherClass" + + # Deserialize back + resp2 = ExploreResponse.model_validate(data) + assert resp2.target.name == "MyClass" + assert len(resp2.siblings) == 1 + assert resp2.siblings[0].name == "OtherClass" + + +def test_explore_response_markdown_rendering(): + """Test ExploreResponse.format('markdown') renders correctly.""" + target = SymbolInfo( + name="MyClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["MyClass"], + ) + + sibling = SymbolInfo( + name="OtherClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=20, character=1), end=Position(line=30, character=1) + ), + path=["OtherClass"], + ) + + dependency = SymbolInfo( + name="Helper", + kind=SymbolKind.Class, + file_path=Path("utils.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["Helper"], + ) + + dependent = SymbolInfo( + name="Main", + kind=SymbolKind.Class, + file_path=Path("main.py"), + range=Range( + start=Position(line=10, character=1), end=Position(line=20, character=1) + ), + path=["Main"], + ) + + hierarchy = HierarchyInfo( + parents=[ + SymbolInfo( + name="BaseClass", + kind=SymbolKind.Class, + file_path=Path("base.py"), + range=Range( + start=Position(line=1, character=1), + end=Position(line=10, character=1), + ), + path=["BaseClass"], + ) + ], + children=[ + SymbolInfo( + name="SubClass", + kind=SymbolKind.Class, + file_path=Path("sub.py"), + range=Range( + start=Position(line=1, character=1), + end=Position(line=10, character=1), + ), + path=["SubClass"], + ) + ], + ) + + calls = CallHierarchy( + incoming=[ + CallHierarchyItem( + name="caller_func", + kind=SymbolKind.Function, + file_path=Path("caller.py"), + range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + selection_range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + ) + ], + outgoing=[ + CallHierarchyItem( + name="callee_func", + kind=SymbolKind.Function, + file_path=Path("callee.py"), + range=Range( + start=Position(line=10, character=1), + end=Position(line=10, character=10), + ), + selection_range=Range( + start=Position(line=10, character=1), + end=Position(line=10, character=10), + ), + ) + ], + ) + + resp = ExploreResponse( + target=target, + siblings=[sibling], + dependencies=[dependency], + dependents=[dependent], + hierarchy=hierarchy, + calls=calls, + ) + + markdown = resp.format("markdown") + + assert "# Explore: `MyClass` (`class`)" in markdown + assert "## Siblings" in markdown + assert "- `OtherClass` (`class`) at line 21" in markdown + assert "## Dependencies" in markdown + assert "- `Helper` (`class`) in `utils.py`" in markdown + assert "## Dependents" in markdown + assert "- `Main` (`class`) in `main.py`" in markdown + assert "## Hierarchy" in markdown + assert "### Parents" in markdown + assert "- `BaseClass` (`class`) in `base.py`" in markdown + assert "### Children" in markdown + assert "- `SubClass` (`class`) in `sub.py`" in markdown + assert "## Call Hierarchy" in markdown + assert "### Incoming Calls" in markdown + assert "- `caller_func` (`function`) at `caller.py:6`" in markdown + assert "### Outgoing Calls" in markdown + assert "- `callee_func` (`function`) at `callee.py:11`" in markdown + assert "Use this map to understand" in markdown From b24599c4f3d1384b2f8bf8e7c489ec16d017bfdf Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 02:33:21 +0800 Subject: [PATCH 09/10] docs(schema): add documentation for Inspect and Explore APIs - Add inspect.md: documents usage-oriented symbol inspection API - Add explore.md: documents relationship-oriented code exploration API - Both docs include overview, use cases, schema, and examples --- schema/explore.md | 91 +++++++++++++++++++++++++++++++++++++++++++++++ schema/inspect.md | 89 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 schema/explore.md create mode 100644 schema/inspect.md diff --git a/schema/explore.md b/schema/explore.md new file mode 100644 index 0000000..d030b2d --- /dev/null +++ b/schema/explore.md @@ -0,0 +1,91 @@ +# Explore API + +## Overview +The Explore API provides "what's around" information for a specific code element. It builds a relationship map of the symbol's neighborhood, including its siblings, dependencies, dependents, and hierarchical position. This API is designed to help Agents build a mental map of the code structure and discover related architectural context. + +## When to Use +- **Building a mental map**: When an Agent first encounters a new class or module and needs to understand its role in the system. +- **Impact analysis**: When an Agent wants to see what other components depend on a specific class before making changes. +- **Discovering related code**: When an Agent is looking for similar implementations or helper classes in the same neighborhood. +- **Architectural context**: When an Agent needs to understand the "big picture" of how a component fits into the overall project structure. + +## Key Differences from Hierarchy API +| Feature | Explore API | Hierarchy API | +|---------|-------------|---------------| +| **Scope** | Neighborhood & Relationships | Deep Tree Traversal | +| **Direction** | Multi-directional (Up, Down, Sideways) | Vertical (Parent/Child) | +| **Goal** | Contextual Mapping | Structural Navigation | +| **Content** | Siblings, Callers, Callees, Types | Call/Type Hierarchy Tree | + +## Relationship Types +- **Hierarchy**: The parent/child relationship (e.g., class members, module contents). +- **Calls**: Outgoing calls (dependencies) and incoming calls (dependents). +- **Types**: Type relationships (e.g., base classes, interface implementations). +- **Siblings**: Other symbols defined in the same scope or file. + +## Request Schema +- `locate`: The target symbol to explore. +- `depth`: (Integer) How many levels of relationships to traverse (default: 1). +- `relationship_types`: (List of Strings) Which types of relationships to include (e.g., `["calls", "hierarchy", "siblings"]`). + +## Response Schema +- `center`: The symbol at the center of the exploration. +- `relationships`: A structured map of related symbols grouped by relationship type. +- `summary`: A high-level description of the symbol's neighborhood. + +## Example: Exploring a class +**Request:** +```json +{ + "locate": { + "file_path": "src/models/user.py", + "scope": { + "symbol_path": ["User"] + } + }, + "relationship_types": ["hierarchy", "siblings", "calls"], + "depth": 1 +} +``` + +**Response:** +```json +{ + "center": { "name": "User", "kind": "class" }, + "relationships": { + "hierarchy": [ + { "name": "User.validate", "kind": "method" }, + { "name": "User.save", "kind": "method" } + ], + "siblings": [ + { "name": "UserRole", "kind": "enum" }, + { "name": "AnonymousUser", "kind": "class" } + ], + "dependents": [ + { "name": "AuthService", "kind": "class", "file": "src/services/auth.py" } + ] + } +} +``` + +**Markdown Output:** +```markdown +# Explore: `User` (class) + +`User` is a central model in `src/models/user.py`, primarily used by `AuthService`. + +## Hierarchy (Members) +- `validate` (method) +- `save` (method) + +## Siblings (In same file) +- `UserRole` (enum) +- `AnonymousUser` (class) + +## Dependents (Used by) +- `AuthService` (src/services/auth.py) +``` + +## See Also +- [Outline API](./outline.md) +- [Reference API](./reference.md) diff --git a/schema/inspect.md b/schema/inspect.md new file mode 100644 index 0000000..6c26cc8 --- /dev/null +++ b/schema/inspect.md @@ -0,0 +1,89 @@ +# Inspect API + +## Overview +The Inspect API provides "how to use" information for a specific symbol. While the Symbol API focuses on "what it is" (implementation details), the Inspect API is designed to help Agents understand how to correctly call or interact with a symbol by providing its signature, documentation, and real-world usage examples. + +## When to Use +- **Learning to call an API**: When an Agent needs to know the parameters, types, and return values of a function. +- **Understanding usage patterns**: When an Agent wants to see how other parts of the codebase call a specific method to avoid common mistakes. +- **Contextual documentation**: When the Agent needs a high-level summary of a symbol's purpose without reading the entire implementation. + +## Key Differences from Symbol API +| Feature | Inspect API | Symbol API | +|---------|-------------|------------| +| **Primary Focus** | How to use the symbol | What the symbol is/does | +| **Content** | Signature, Docstring, Usage Examples | Full Source Code, Implementation | +| **Goal** | Integration & Calling | Understanding Implementation | +| **Context** | External perspective | Internal perspective | + +## Request Schema +- `locate`: The target symbol to inspect (supports `file_path`, `symbol_path`, `find`, etc.). +- `include_usage`: (Boolean) Whether to include real-world usage examples from the codebase. +- `max_usage_examples`: (Integer) Maximum number of usage examples to return. + +## Response Schema +- `symbol`: Basic information about the symbol (name, kind, location). +- `signature`: The formal signature of the symbol (e.g., function parameters and return types). +- `documentation`: The docstring or comments associated with the symbol. +- `usages`: A list of code snippets showing how the symbol is used elsewhere. + +## Example: Inspecting a function +**Request:** +```json +{ + "locate": { + "file_path": "src/utils/auth.py", + "scope": { + "symbol_path": ["verify_token"] + } + }, + "include_usage": true, + "max_usage_examples": 2 +} +``` + +**Response:** +```json +{ + "symbol": { + "name": "verify_token", + "kind": "function", + "location": { "uri": "file:///src/utils/auth.py", "range": { ... } } + }, + "signature": "def verify_token(token: str, secret: str = None) -> UserPayload", + "documentation": "Verifies a JWT token and returns the decoded payload.\n\n:param token: The JWT string to verify.\n:param secret: Optional secret override.", + "usages": [ + { + "file_path": "src/api/middleware.py", + "line": 42, + "code": "payload = verify_token(auth_header.split(' ')[1])" + } + ] +} +``` + +**Markdown Output:** +```markdown +# Inspect: `verify_token` + +**Signature:** `def verify_token(token: str, secret: str = None) -> UserPayload` + +--- + +Verifies a JWT token and returns the decoded payload. + +:param token: The JWT string to verify. +:param secret: Optional secret override. + +## Usage Examples + +### src/api/middleware.py:42 +```python +41 | auth_header = request.headers.get('Authorization') +42 | payload = verify_token(auth_header.split(' ')[1]) +43 | request.user = payload +``` + +## See Also +- [Symbol API](./symbol.md) +- [Reference API](./reference.md) From 3c710543f7e31ebbdfe574766af574c8ddb45d87 Mon Sep 17 00:00:00 2001 From: observerw Date: Sun, 25 Jan 2026 22:40:04 +0800 Subject: [PATCH 10/10] fix: test error --- src/lsap/capability/relation.py | 14 +- src/lsap/schema/draft/explore.py | 2 +- src/lsap/schema/draft/inspect.py | 2 +- tests/__init__.py | 0 tests/framework/__init__.py | 0 tests/framework/lsp.py | 595 +++++++++++++++++++++++++++++++ tests/test_explore_schema.py | 14 +- tests/test_inspect_schema.py | 15 +- tests/test_relation.py | 17 +- tests/test_relation_e2e.py | 535 +++++++++++++++++++++++++++ tests/test_rename_glob.py | 3 +- 11 files changed, 1169 insertions(+), 28 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/framework/__init__.py create mode 100644 tests/framework/lsp.py create mode 100644 tests/test_relation_e2e.py diff --git a/src/lsap/capability/relation.py b/src/lsap/capability/relation.py index ef54f70..8e8b8c2 100644 --- a/src/lsap/capability/relation.py +++ b/src/lsap/capability/relation.py @@ -28,14 +28,24 @@ def locate(self) -> LocateCapability: @override async def __call__(self, req: RelationRequest) -> RelationResponse | None: + from lsap.exception import NotFoundError + # Resolve source symbol source_req = LocateRequest(locate=req.source) - if not (source_loc := await self.locate(source_req)): + try: + source_loc = await self.locate(source_req) + if not source_loc: + return None + except NotFoundError: return None # Resolve target symbol target_req = LocateRequest(locate=req.target) - if not (target_loc := await self.locate(target_req)): + try: + target_loc = await self.locate(target_req) + if not target_loc: + return None + except NotFoundError: return None # Get CallHierarchyItems for source and target diff --git a/src/lsap/schema/draft/explore.py b/src/lsap/schema/draft/explore.py index 33f7678..ebd299b 100644 --- a/src/lsap/schema/draft/explore.py +++ b/src/lsap/schema/draft/explore.py @@ -176,7 +176,7 @@ class ExploreResponse(Response): __all__ = [ - "HierarchyInfo", "ExploreRequest", "ExploreResponse", + "HierarchyInfo", ] diff --git a/src/lsap/schema/draft/inspect.py b/src/lsap/schema/draft/inspect.py index b379a33..a53a161 100644 --- a/src/lsap/schema/draft/inspect.py +++ b/src/lsap/schema/draft/inspect.py @@ -160,7 +160,7 @@ class InspectResponse(Response): __all__ = [ - "UsageExample", "InspectRequest", "InspectResponse", + "UsageExample", ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/framework/__init__.py b/tests/framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/framework/lsp.py b/tests/framework/lsp.py new file mode 100644 index 0000000..1050228 --- /dev/null +++ b/tests/framework/lsp.py @@ -0,0 +1,595 @@ +from __future__ import annotations + +import tempfile +from collections.abc import AsyncGenerator, Sequence +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any + +import attrs +from lsp_client.capability.request.completion import WithRequestCompletion +from lsp_client.capability.request.definition import WithRequestDefinition +from lsp_client.capability.request.document_symbol import WithRequestDocumentSymbol +from lsp_client.capability.request.hover import WithRequestHover +from lsp_client.capability.request.reference import WithRequestReferences +from lsp_client.client.abc import Client +from lsp_client.utils.types import Position, Range, lsp_type + +type LspResponse[R] = R | None + + +@attrs.define +class LspInteraction[C: Client]: + client: C + workspace_root: Path + + @property + def resolved_workspace(self) -> Path: + """Get resolved workspace path for path comparisons.""" + return self.workspace_root.resolve() + + def full_path(self, relative_path: str) -> Path: + # Return the original path (not resolved) for file operations + # This is important for symlinked fixtures + return self.workspace_root / relative_path + + def full_path_resolved(self, relative_path: str) -> Path: + """Return resolved path for comparison with pyrefly responses.""" + return (self.workspace_root / relative_path).resolve() + + async def create_file(self, relative_path: str, content: str) -> Path: + path = self.full_path(relative_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + async def request_definition( + self, relative_path: str, line: int, column: int + ) -> DefinitionAssertion: + assert isinstance(self.client, WithRequestDefinition) + path = self.full_path(relative_path) + resp = await self.client.request_definition( + file_path=path, + position=Position(line=line, character=column), + ) + return DefinitionAssertion(self, resp) + + async def request_hover( + self, relative_path: str, line: int, column: int + ) -> HoverAssertion: + assert isinstance(self.client, WithRequestHover) + path = self.full_path(relative_path) + resp = await self.client.request_hover( + file_path=path, + position=Position(line=line, character=column), + ) + return HoverAssertion(self, resp) + + async def request_completion( + self, relative_path: str, line: int, column: int + ) -> CompletionAssertion: + assert isinstance(self.client, WithRequestCompletion) + path = self.full_path(relative_path) + resp = await self.client.request_completion( + file_path=path, + position=Position(line=line, character=column), + ) + return CompletionAssertion(self, resp) + + async def request_references( + self, relative_path: str, line: int, column: int + ) -> ReferencesAssertion: + assert isinstance(self.client, WithRequestReferences) + path = self.full_path(relative_path) + resp = await self.client.request_references( + file_path=path, + position=Position(line=line, character=column), + ) + return ReferencesAssertion(self, resp) + + async def request_document_symbols( + self, relative_path: str + ) -> DocumentSymbolsAssertion: + assert isinstance(self.client, WithRequestDocumentSymbol) + path = self.full_path(relative_path) + resp = await self.client.request_document_symbol(file_path=path) + return DocumentSymbolsAssertion(self, resp) + + +@attrs.define +class DefinitionAssertion: + interaction: LspInteraction[Any] + response: ( + lsp_type.Location + | Sequence[lsp_type.Location] + | Sequence[lsp_type.LocationLink] + | None + ) + + def expect_definition( + self, + relative_path: str, + start_line: int, + start_col: int, + end_line: int, + end_col: int, + ) -> None: + assert self.response is not None, "Definition response is None" + + # Use resolved path for comparison with pyrefly responses + expected_path = self.interaction.full_path_resolved(relative_path) + expected_range = Range( + start=Position(line=start_line, character=start_col), + end=Position(line=end_line, character=end_col), + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using resolved paths to handle symlinks properly + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using resolved paths to handle symlinks + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Pyrefly may return resolved paths + # Compare using resolved paths to handle symlinks properly + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using resolved paths to handle symlinks + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Pyrefly may return resolved paths, so we need to handle symlinks + # Try to resolve both paths and compare + try: + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + except ValueError: + # Fallback to relative path comparison if resolve fails + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Pyrefly may return resolved paths, so we need to handle symlinks + try: + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + except ValueError: + # Fallback to relative path comparison + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + path_match = False + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + +@attrs.define +class HoverAssertion: + interaction: LspInteraction[Any] + response: lsp_type.MarkupContent | None + + def expect_content(self, pattern: str) -> None: + assert self.response is not None, "Hover response is None" + assert pattern in self.response.value, ( + f"Expected '{pattern}' in hover content, got '{self.response.value}'" + ) + + +@attrs.define +class CompletionAssertion: + interaction: LspInteraction[Any] + response: Sequence[lsp_type.CompletionItem] + + def expect_label(self, label: str) -> None: + labels = [item.label for item in self.response] + assert label in labels, ( + f"Expected completion label '{label}' not found in {labels}" + ) + + +@attrs.define +class ReferencesAssertion: + interaction: LspInteraction[Any] + response: Sequence[lsp_type.Location] | None + + def expect_reference( + self, + relative_path: str, + start_line: int, + start_col: int, + end_line: int, + end_col: int, + ) -> None: + assert self.response is not None, "References response is None" + expected_path = self.interaction.full_path(relative_path) + expected_range = Range( + start=Position(line=start_line, character=start_col), + end=Position(line=end_line, character=end_col), + ) + + found = False + for loc in self.response: + actual_path = self.interaction.client.from_uri(loc.uri, relative=False) + if ( + Path(actual_path).resolve() == expected_path + and loc.range == expected_range + ): + found = True + break + assert found, f"Reference not found at {expected_path}:{expected_range}" + + +@attrs.define +class DocumentSymbolsAssertion: + interaction: LspInteraction[Any] + response: ( + Sequence[lsp_type.SymbolInformation] | Sequence[lsp_type.DocumentSymbol] | None + ) + _last_found_names: list[str] = attrs.field(factory=list, init=False) + + def expect_symbol(self, name: str, kind: lsp_type.SymbolKind | None = None) -> None: + assert self.response is not None, "Document symbols response is None" + + def check_symbols( + symbols: Sequence[lsp_type.SymbolInformation] + | Sequence[lsp_type.DocumentSymbol], + found_names: list[str], + ) -> bool: + for sym in symbols: + if isinstance(sym, lsp_type.DocumentSymbol): + found_names.append(f"{sym.name} ({sym.kind})") + if sym.name == name and (kind is None or sym.kind == kind): + return True + if sym.children and check_symbols(sym.children, found_names): + return True + elif isinstance(sym, lsp_type.SymbolInformation): + found_names.append(f"{sym.name} ({sym.kind})") + if sym.name == name and (kind is None or sym.kind == kind): + return True + return False + + self._last_found_names = [] + if not check_symbols(self.response, self._last_found_names): + print(f"Actually found: {self._last_found_names}") + raise AssertionError(f"Symbol '{name}' not found") + + +@asynccontextmanager +async def lsp_interaction_context[C: Client]( + client_cls: type[C], workspace_root: Path | None = None, **client_kwargs: Any +) -> AsyncGenerator[LspInteraction[C], None]: + if workspace_root is None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir).resolve() + async with client_cls(workspace=root, **client_kwargs) as client: + yield LspInteraction(client=client, workspace_root=root) + else: + # Use original path to preserve symlinks + # This is important for test fixtures that are symlinked + root = workspace_root + async with client_cls(workspace=root, **client_kwargs) as client: + yield LspInteraction(client=client, workspace_root=root) diff --git a/tests/test_explore_schema.py b/tests/test_explore_schema.py index 93b52dd..88a0363 100644 --- a/tests/test_explore_schema.py +++ b/tests/test_explore_schema.py @@ -1,15 +1,17 @@ from pathlib import Path + import pytest from pydantic import ValidationError + from lsap.schema.draft.explore import ExploreRequest, ExploreResponse, HierarchyInfo from lsap.schema.locate import Locate from lsap.schema.models import ( - SymbolInfo, CallHierarchy, CallHierarchyItem, - SymbolKind, Position, Range, + SymbolInfo, + SymbolKind, ) @@ -185,10 +187,6 @@ def test_explore_response_markdown_rendering(): start=Position(line=5, character=1), end=Position(line=5, character=10), ), - selection_range=Range( - start=Position(line=5, character=1), - end=Position(line=5, character=10), - ), ) ], outgoing=[ @@ -200,10 +198,6 @@ def test_explore_response_markdown_rendering(): start=Position(line=10, character=1), end=Position(line=10, character=10), ), - selection_range=Range( - start=Position(line=10, character=1), - end=Position(line=10, character=10), - ), ) ], ) diff --git a/tests/test_inspect_schema.py b/tests/test_inspect_schema.py index 0288f0b..2890895 100644 --- a/tests/test_inspect_schema.py +++ b/tests/test_inspect_schema.py @@ -1,17 +1,19 @@ from pathlib import Path + import pytest from pydantic import ValidationError + from lsap.schema.draft.inspect import InspectRequest, InspectResponse, UsageExample from lsap.schema.locate import Locate from lsap.schema.models import ( - SymbolDetailInfo, + CallHierarchy, + CallHierarchyItem, Location, Position, Range, - SymbolKind, + SymbolDetailInfo, SymbolInfo, - CallHierarchy, - CallHierarchyItem, + SymbolKind, ) @@ -73,6 +75,7 @@ def test_usage_example_model(): ) assert example.code == "func('test')" assert example.call_pattern == "func(arg)" + assert example.context is not None assert example.context.name == "caller" @@ -159,10 +162,6 @@ def test_inspect_response_markdown_rendering(): start=Position(line=5, character=1), end=Position(line=5, character=10), ), - selection_range=Range( - start=Position(line=5, character=1), - end=Position(line=5, character=10), - ), ) ], outgoing=[], diff --git a/tests/test_relation.py b/tests/test_relation.py index 37b6bb2..2995a4e 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -16,6 +16,8 @@ from lsp_client.protocol import CapabilityClientProtocol from lsp_client.protocol.lang import LanguageConfig from lsp_client.utils.config import ConfigurationMap +from lsp_client.utils.types import AnyPath +from lsp_client.utils.types import Position as LSPPosition from lsp_client.utils.workspace import ( DEFAULT_WORKSPACE_DIR, Workspace, @@ -29,7 +31,6 @@ LanguageKind, SymbolKind, ) -from lsprotocol.types import Position as LSPPosition from lsprotocol.types import Range as LSPRange from lsap.capability.relation import RelationCapability @@ -421,7 +422,11 @@ def test_relation_request_with_nested_symbol_path(): max_depth=3, ) + assert req.source.scope is not None + assert isinstance(req.source.scope, SymbolScope) assert req.source.scope.symbol_path == ["UserService", "get_user"] + assert req.target.scope is not None + assert isinstance(req.target.scope, SymbolScope) assert req.target.scope.symbol_path == ["Database", "query"] assert req.max_depth == 3 @@ -492,11 +497,12 @@ async def open_files(self, *file_paths): yield async def request_document_symbol_list( - self, file_path: Path + self, file_path: AnyPath ) -> list[DocumentSymbol]: """Mock document symbol list - returns a single function symbol based on file name.""" + path = Path(file_path) # Extract symbol name from file path (e.g., test_A.py -> A) - name = file_path.stem.replace("test_", "") + name = path.stem.replace("test_", "") if name in self.call_graph or any( name in calls for calls in self.call_graph.values() ): @@ -536,11 +542,12 @@ def _make_call_hierarchy_item(self, name: str) -> CallHierarchyItem: ) async def prepare_call_hierarchy( - self, file_path: Path, position: LSPPosition + self, file_path: AnyPath, position: LSPPosition ) -> list[CallHierarchyItem] | None: """Mock prepare_call_hierarchy - returns item based on file path.""" + path = Path(file_path) # Extract symbol name from file path (e.g., test_A.py -> A) - name = file_path.stem.replace("test_", "") + name = path.stem.replace("test_", "") if name in self.call_graph or any( name in calls for calls in self.call_graph.values() ): diff --git a/tests/test_relation_e2e.py b/tests/test_relation_e2e.py new file mode 100644 index 0000000..8636a00 --- /dev/null +++ b/tests/test_relation_e2e.py @@ -0,0 +1,535 @@ +""" +End-to-end tests for RelationCapability using BasedpyrightClient. + +These tests verify the RelationCapability's ability to find call chains +between functions in real Python code using a real Language Server. +""" + +from __future__ import annotations + +import anyio +import pytest +from lsp_client.clients.basedpyright import BasedpyrightClient + +from lsap.capability.relation import RelationCapability +from lsap.schema.locate import Locate, SymbolScope +from lsap.schema.relation import RelationRequest + +from .framework.lsp import lsp_interaction_context + + +@pytest.mark.e2e +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_direct_call(): + """Test finding a direct call relationship: A -> B.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + # Create file with direct call + content = ''' +def caller(): + """The calling function.""" + return callee() + +def callee(): + """The called function.""" + return 42 +''' + file_path = await interaction.create_file("direct_call.py", content) + + # Wait for LSP indexing + await anyio.sleep(1) + + # Create RelationCapability + capability = RelationCapability(client=interaction.client) # type: ignore + + # Request to find path from caller to callee + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["caller"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["callee"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Verify response + assert resp is not None, "Response should not be None" + assert resp.source.name == "caller" + assert resp.target.name == "callee" + assert len(resp.chains) == 1, f"Expected 1 chain, got {len(resp.chains)}" + assert len(resp.chains[0]) == 2, "Chain should have 2 nodes: caller -> callee" + assert resp.chains[0][0].name == "caller" + assert resp.chains[0][1].name == "callee" + + # Test markdown formatting + markdown = resp.format() + assert "caller" in markdown + assert "callee" in markdown + assert "Found 1 call chain(s)" in markdown + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_indirect_call(): + """Test finding an indirect call relationship: A -> B -> C.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def entry_point(): + """Entry point function.""" + return middle_layer() + +def middle_layer(): + """Middle layer function.""" + return final_destination() + +def final_destination(): + """Final destination function.""" + return "success" +''' + file_path = await interaction.create_file("indirect_call.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["entry_point"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["final_destination"]), + ), + max_depth=10, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "entry_point" + assert resp.target.name == "final_destination" + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3, "Chain should have 3 nodes" + assert resp.chains[0][0].name == "entry_point" + assert resp.chains[0][1].name == "middle_layer" + assert resp.chains[0][2].name == "final_destination" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_multiple_paths(): + """Test finding multiple paths between two functions.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def source(): + """Source function with multiple call paths.""" + path_a() + path_b() + return "done" + +def path_a(): + """First path to target.""" + return target() + +def path_b(): + """Second path to target.""" + return target() + +def target(): + """Target function.""" + return 100 +''' + file_path = await interaction.create_file("multiple_paths.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["source"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["target"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "source" + assert resp.target.name == "target" + assert len(resp.chains) == 2, f"Expected 2 paths, got {len(resp.chains)}" + + # Verify both paths exist + middle_nodes = {chain[1].name for chain in resp.chains} + assert middle_nodes == {"path_a", "path_b"} + + # Verify all chains have correct structure + for chain in resp.chains: + assert len(chain) == 3 + assert chain[0].name == "source" + assert chain[2].name == "target" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_no_connection(): + """Test when there's no connection between functions.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def isolated_a(): + """Isolated function A.""" + return 1 + +def isolated_b(): + """Isolated function B.""" + return 2 +''' + file_path = await interaction.create_file("no_connection.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["isolated_a"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["isolated_b"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "isolated_a" + assert resp.target.name == "isolated_b" + assert len(resp.chains) == 0, "Should find no chains" + + # Verify markdown shows no connection + markdown = resp.format() + assert "No connection found" in markdown + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_max_depth_limit(): + """Test that max_depth is respected.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def level_0(): + """Level 0.""" + return level_1() + +def level_1(): + """Level 1.""" + return level_2() + +def level_2(): + """Level 2.""" + return level_3() + +def level_3(): + """Level 3.""" + return "deep" +''' + file_path = await interaction.create_file("max_depth.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # With max_depth=2, can reach level_2 + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_0"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_2"]), + ), + max_depth=2, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # level_0 -> level_1 -> level_2 + + # With max_depth=1, cannot reach level_2 + req_shallow = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_0"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_2"]), + ), + max_depth=1, + ) + + resp_shallow = await capability(req_shallow) + assert resp_shallow is not None + assert len(resp_shallow.chains) == 0, "Should not find path within max_depth=1" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_with_class_methods(): + """Test finding call chains involving class methods.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +class Calculator: + """A simple calculator class.""" + + def add(self, a: int, b: int) -> int: + """Add two numbers.""" + return self._internal_add(a, b) + + def _internal_add(self, a: int, b: int) -> int: + """Internal addition method.""" + return a + b + +def use_calculator(): + """Use the calculator.""" + calc = Calculator() + return calc.add(1, 2) +''' + file_path = await interaction.create_file("class_methods.py", content) + + await anyio.sleep(1.5) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # Find path from use_calculator to _internal_add + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["use_calculator"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["Calculator", "_internal_add"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "use_calculator" + assert resp.target.name == "_internal_add" + # Should find path: use_calculator -> Calculator.add -> Calculator._internal_add + if len(resp.chains) > 0: + assert resp.chains[0][0].name == "use_calculator" + assert resp.chains[0][-1].name == "_internal_add" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_cross_file(): + """Test finding call chains across multiple files.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + # Create module with utility function + util_content = ''' +def utility_function(): + """A utility function.""" + return "utility" +''' + util_path = await interaction.create_file("utils.py", util_content) + + # Create main file that imports and calls utility + main_content = ''' +from utils import utility_function + +def main(): + """Main function.""" + return utility_function() +''' + main_path = await interaction.create_file("main.py", main_content) + + await anyio.sleep(2) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # Find path from main to utility_function + req = RelationRequest( + source=Locate( + file_path=main_path, + scope=SymbolScope(symbol_path=["main"]), + ), + target=Locate( + file_path=util_path, + scope=SymbolScope(symbol_path=["utility_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "main" + assert resp.target.name == "utility_function" + # Should find direct call: main -> utility_function + if len(resp.chains) > 0: + assert len(resp.chains[0]) == 2 + assert resp.chains[0][0].name == "main" + assert resp.chains[0][1].name == "utility_function" + # Verify different files + assert resp.chains[0][0].file_path != resp.chains[0][1].file_path + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_invalid_source(): + """Test when source symbol cannot be found.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def valid_function(): + """A valid function.""" + return 42 +''' + file_path = await interaction.create_file("invalid_source.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["nonexistent_function"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["valid_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Should return None when source cannot be located + assert resp is None + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_invalid_target(): + """Test when target symbol cannot be found.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def valid_function(): + """A valid function.""" + return 42 +''' + file_path = await interaction.create_file("invalid_target.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["valid_function"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["nonexistent_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Should return None when target cannot be located + assert resp is None + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_different_path_lengths(): + """Test finding paths of different lengths to the same target.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def source(): + """Source with multiple paths of different lengths.""" + direct() + path1() + path2() + +def direct(): + """Direct path to target.""" + return target() + +def path1(): + """One hop to target.""" + return path1_helper() + +def path1_helper(): + """Helper for path1.""" + return target() + +def path2(): + """Two hops to target.""" + return path2_a() + +def path2_a(): + """First hop in path2.""" + return path2_b() + +def path2_b(): + """Second hop in path2.""" + return target() + +def target(): + """The target function.""" + return 100 +''' + file_path = await interaction.create_file("different_lengths.py", content) + + await anyio.sleep(1.5) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["source"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["target"]), + ), + max_depth=10, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "source" + assert resp.target.name == "target" + # Should find multiple paths + assert len(resp.chains) >= 1, "Should find at least one path" + + # Verify we have paths of different lengths + path_lengths = sorted([len(chain) for chain in resp.chains]) + # Should have paths like: source->direct->target (3), source->path1->path1_helper->target (4), etc. + assert len(set(path_lengths)) > 1, "Should have paths of different lengths" diff --git a/tests/test_rename_glob.py b/tests/test_rename_glob.py index ceeda0d..abb5786 100644 --- a/tests/test_rename_glob.py +++ b/tests/test_rename_glob.py @@ -15,12 +15,13 @@ from lsprotocol.types import ( Range as LSPRange, ) -from test_rename_e2e import E2ERenameClient from lsap.capability.rename import RenameExecuteCapability, RenamePreviewCapability from lsap.schema.locate import Locate from lsap.schema.rename import RenameExecuteRequest, RenamePreviewRequest +from .test_rename_e2e import E2ERenameClient + class GlobPatternRenameClient(E2ERenameClient): """Client that returns edits for multiple files with different paths."""