diff --git a/docs/schemas/draft/relation.md b/docs/schemas/draft/relation.md new file mode 100644 index 0000000..6b49e96 --- /dev/null +++ b/docs/schemas/draft/relation.md @@ -0,0 +1,96 @@ +# Relation API + +**Core Value**: Trace the call path between two symbols — answering "how does A reach B?" + +This is a high-value query for: + +- **Code Flow Understanding**: How does `handle_request` eventually call `db.query`? +- **Impact Analysis**: If I modify function X, which entry points are affected? +- **Architecture Validation**: Verify that module A never directly/indirectly calls module B + +## RelationRequest + +| Field | Type | Default | Description | +| :---------- | :---------------------------- | :------- | :--------------------------------------- | +| `source` | [`Locate`](../locate.md) | Required | The starting symbol for the path search. | +| `target` | [`Locate`](../locate.md) | Required | The ending symbol for the path search. | +| `max_depth` | `number` | `10` | Maximum search depth. | + +## RelationResponse + +| Field | Type | Description | +| :---------- | :---------------- | :---------------------------------------------------- | +| `request` | `RelationRequest` | The original request. | +| `source` | `ChainNode` | The resolved source symbol. | +| `target` | `ChainNode` | The resolved target symbol. | +| `chains` | `ChainNode[][]` | All paths found. Each path is a sequence of nodes. | + +The maximum depth used for the search is available as `request.max_depth`, since the response includes the original `RelationRequest`. + +### ChainNode + +A lightweight symbol representation for path display: + +| Field | Type | Description | +| :---------- | :------- | :--------------------------------- | +| `name` | `string` | Symbol name (e.g., `get_user`) | +| `kind` | `string` | Symbol kind (e.g., `Function`) | +| `file_path` | `Path` | File containing the symbol | +| `detail` | `string` | Optional: signature or extra info | + +> **Design Note**: Unlike `HierarchyItem`, `ChainNode` has no `level` or `is_cycle` fields — the array index naturally represents position in the chain. + +## Example + +### How does `handle_request` reach `db.query`? + +#### Request + +```json +{ + "source": { + "file_path": "src/controllers.py", + "scope": { "symbol_path": ["handle_request"] } + }, + "target": { + "file_path": "src/db.py", + "scope": { "symbol_path": ["query"] } + }, + "max_depth": 5 +} +``` + +#### Response (Markdown Rendered) + +```markdown +# Relation: `handle_request` → `query` + +Found 2 call chain(s): + +### Chain 1 +1. **handle_request** (`Function`) - `src/controllers.py` +2. **UserService.get_user** (`Method`) - `src/services/user.py` +3. **db.query** (`Function`) - `src/db.py` + +### Chain 2 +1. **handle_request** (`Function`) - `src/controllers.py` +2. **AuthService.validate_token** (`Method`) - `src/services/auth.py` +3. **SessionManager.get_session** (`Method`) - `src/services/session.py` +4. **db.query** (`Function`) - `src/db.py` +``` + +## Implementation + +Orchestrates LSP `Call Hierarchy` requests: + +1. **Resolve**: Use `textDocument/prepareCallHierarchy` to get `CallHierarchyItem` for both endpoints +2. **Search**: BFS via `callHierarchy/outgoingCalls` from source (optionally bidirectional with `incomingCalls` from target) +3. **Reconstruct**: Build paths when frontiers meet; filter by `max_depth` + +## Design Decisions + +| Decision | Rationale | +| :------- | :-------- | +| No pagination | `max_depth` bounds result size; path enumeration is typically small | +| `ChainNode` over `HierarchyItem` | Chains are linear; no tree semantics needed | +| Bidirectional search optional | Optimization for large graphs; not required for correctness | diff --git a/schema/explore.md b/schema/explore.md new file mode 100644 index 0000000..d030b2d --- /dev/null +++ b/schema/explore.md @@ -0,0 +1,91 @@ +# Explore API + +## Overview +The Explore API provides "what's around" information for a specific code element. It builds a relationship map of the symbol's neighborhood, including its siblings, dependencies, dependents, and hierarchical position. This API is designed to help Agents build a mental map of the code structure and discover related architectural context. + +## When to Use +- **Building a mental map**: When an Agent first encounters a new class or module and needs to understand its role in the system. +- **Impact analysis**: When an Agent wants to see what other components depend on a specific class before making changes. +- **Discovering related code**: When an Agent is looking for similar implementations or helper classes in the same neighborhood. +- **Architectural context**: When an Agent needs to understand the "big picture" of how a component fits into the overall project structure. + +## Key Differences from Hierarchy API +| Feature | Explore API | Hierarchy API | +|---------|-------------|---------------| +| **Scope** | Neighborhood & Relationships | Deep Tree Traversal | +| **Direction** | Multi-directional (Up, Down, Sideways) | Vertical (Parent/Child) | +| **Goal** | Contextual Mapping | Structural Navigation | +| **Content** | Siblings, Callers, Callees, Types | Call/Type Hierarchy Tree | + +## Relationship Types +- **Hierarchy**: The parent/child relationship (e.g., class members, module contents). +- **Calls**: Outgoing calls (dependencies) and incoming calls (dependents). +- **Types**: Type relationships (e.g., base classes, interface implementations). +- **Siblings**: Other symbols defined in the same scope or file. + +## Request Schema +- `locate`: The target symbol to explore. +- `depth`: (Integer) How many levels of relationships to traverse (default: 1). +- `relationship_types`: (List of Strings) Which types of relationships to include (e.g., `["calls", "hierarchy", "siblings"]`). + +## Response Schema +- `center`: The symbol at the center of the exploration. +- `relationships`: A structured map of related symbols grouped by relationship type. +- `summary`: A high-level description of the symbol's neighborhood. + +## Example: Exploring a class +**Request:** +```json +{ + "locate": { + "file_path": "src/models/user.py", + "scope": { + "symbol_path": ["User"] + } + }, + "relationship_types": ["hierarchy", "siblings", "calls"], + "depth": 1 +} +``` + +**Response:** +```json +{ + "center": { "name": "User", "kind": "class" }, + "relationships": { + "hierarchy": [ + { "name": "User.validate", "kind": "method" }, + { "name": "User.save", "kind": "method" } + ], + "siblings": [ + { "name": "UserRole", "kind": "enum" }, + { "name": "AnonymousUser", "kind": "class" } + ], + "dependents": [ + { "name": "AuthService", "kind": "class", "file": "src/services/auth.py" } + ] + } +} +``` + +**Markdown Output:** +```markdown +# Explore: `User` (class) + +`User` is a central model in `src/models/user.py`, primarily used by `AuthService`. + +## Hierarchy (Members) +- `validate` (method) +- `save` (method) + +## Siblings (In same file) +- `UserRole` (enum) +- `AnonymousUser` (class) + +## Dependents (Used by) +- `AuthService` (src/services/auth.py) +``` + +## See Also +- [Outline API](./outline.md) +- [Reference API](./reference.md) diff --git a/schema/inspect.md b/schema/inspect.md new file mode 100644 index 0000000..6c26cc8 --- /dev/null +++ b/schema/inspect.md @@ -0,0 +1,89 @@ +# Inspect API + +## Overview +The Inspect API provides "how to use" information for a specific symbol. While the Symbol API focuses on "what it is" (implementation details), the Inspect API is designed to help Agents understand how to correctly call or interact with a symbol by providing its signature, documentation, and real-world usage examples. + +## When to Use +- **Learning to call an API**: When an Agent needs to know the parameters, types, and return values of a function. +- **Understanding usage patterns**: When an Agent wants to see how other parts of the codebase call a specific method to avoid common mistakes. +- **Contextual documentation**: When the Agent needs a high-level summary of a symbol's purpose without reading the entire implementation. + +## Key Differences from Symbol API +| Feature | Inspect API | Symbol API | +|---------|-------------|------------| +| **Primary Focus** | How to use the symbol | What the symbol is/does | +| **Content** | Signature, Docstring, Usage Examples | Full Source Code, Implementation | +| **Goal** | Integration & Calling | Understanding Implementation | +| **Context** | External perspective | Internal perspective | + +## Request Schema +- `locate`: The target symbol to inspect (supports `file_path`, `symbol_path`, `find`, etc.). +- `include_usage`: (Boolean) Whether to include real-world usage examples from the codebase. +- `max_usage_examples`: (Integer) Maximum number of usage examples to return. + +## Response Schema +- `symbol`: Basic information about the symbol (name, kind, location). +- `signature`: The formal signature of the symbol (e.g., function parameters and return types). +- `documentation`: The docstring or comments associated with the symbol. +- `usages`: A list of code snippets showing how the symbol is used elsewhere. + +## Example: Inspecting a function +**Request:** +```json +{ + "locate": { + "file_path": "src/utils/auth.py", + "scope": { + "symbol_path": ["verify_token"] + } + }, + "include_usage": true, + "max_usage_examples": 2 +} +``` + +**Response:** +```json +{ + "symbol": { + "name": "verify_token", + "kind": "function", + "location": { "uri": "file:///src/utils/auth.py", "range": { ... } } + }, + "signature": "def verify_token(token: str, secret: str = None) -> UserPayload", + "documentation": "Verifies a JWT token and returns the decoded payload.\n\n:param token: The JWT string to verify.\n:param secret: Optional secret override.", + "usages": [ + { + "file_path": "src/api/middleware.py", + "line": 42, + "code": "payload = verify_token(auth_header.split(' ')[1])" + } + ] +} +``` + +**Markdown Output:** +```markdown +# Inspect: `verify_token` + +**Signature:** `def verify_token(token: str, secret: str = None) -> UserPayload` + +--- + +Verifies a JWT token and returns the decoded payload. + +:param token: The JWT string to verify. +:param secret: Optional secret override. + +## Usage Examples + +### src/api/middleware.py:42 +```python +41 | auth_header = request.headers.get('Authorization') +42 | payload = verify_token(auth_header.split(' ')[1]) +43 | request.user = payload +``` + +## See Also +- [Symbol API](./symbol.md) +- [Reference API](./reference.md) diff --git a/src/lsap/capability/__init__.py b/src/lsap/capability/__init__.py index f7b97be..05071b5 100644 --- a/src/lsap/capability/__init__.py +++ b/src/lsap/capability/__init__.py @@ -4,6 +4,7 @@ from .locate import LocateCapability from .outline import OutlineCapability from .reference import ReferenceCapability +from .relation import RelationCapability from .rename import RenameExecuteCapability, RenamePreviewCapability from .search import SearchCapability from .symbol import SymbolCapability @@ -14,6 +15,7 @@ class Capabilities(TypedDict): locate: LocateCapability outline: OutlineCapability references: ReferenceCapability + relation: RelationCapability rename_preview: RenamePreviewCapability rename_execute: RenameExecuteCapability search: SearchCapability diff --git a/src/lsap/capability/relation.py b/src/lsap/capability/relation.py new file mode 100644 index 0000000..8e8b8c2 --- /dev/null +++ b/src/lsap/capability/relation.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from collections import deque +from functools import cached_property +from pathlib import Path +from typing import override + +from attrs import define +from lsp_client.capability.request import WithRequestCallHierarchy +from lsprotocol.types import ( + CallHierarchyItem, + CallHierarchyOutgoingCallsParams, +) + +from lsap.schema.locate import LocateRequest +from lsap.schema.relation import ChainNode, RelationRequest, RelationResponse +from lsap.utils.capability import ensure_capability + +from .abc import Capability +from .locate import LocateCapability + + +@define +class RelationCapability(Capability[RelationRequest, RelationResponse]): + @cached_property + def locate(self) -> LocateCapability: + return LocateCapability(self.client) + + @override + async def __call__(self, req: RelationRequest) -> RelationResponse | None: + from lsap.exception import NotFoundError + + # Resolve source symbol + source_req = LocateRequest(locate=req.source) + try: + source_loc = await self.locate(source_req) + if not source_loc: + return None + except NotFoundError: + return None + + # Resolve target symbol + target_req = LocateRequest(locate=req.target) + try: + target_loc = await self.locate(target_req) + if not target_loc: + return None + except NotFoundError: + return None + + # Get CallHierarchyItems for source and target + call_hierarchy = ensure_capability(self.client, WithRequestCallHierarchy) + + source_items = await call_hierarchy.prepare_call_hierarchy( + source_loc.file_path, source_loc.position.to_lsp() + ) + if not source_items: + return None + + target_items = await call_hierarchy.prepare_call_hierarchy( + target_loc.file_path, target_loc.position.to_lsp() + ) + if not target_items: + return None + + # For simplicity, use the first (primary) symbol if multiple exist + # This handles cases like overloaded functions or template instantiations + # TODO: Consider searching paths for all source-target combinations if needed + source_node = self._to_chain_node(source_items[0]) + target_node = self._to_chain_node(target_items[0]) + + # Find all paths from source to target + chains = await self._find_paths( + call_hierarchy, source_items[0], target_items[0], req.max_depth + ) + + return RelationResponse( + request=req, + source=source_node, + target=target_node, + chains=chains, + ) + + def _to_chain_node(self, item: CallHierarchyItem) -> ChainNode: + """Convert CallHierarchyItem to ChainNode""" + file_path = Path(self.client.from_uri(item.uri, relative=False)) + return ChainNode( + name=item.name, + kind=item.kind.name if hasattr(item.kind, "name") else str(item.kind), + file_path=file_path, + detail=item.detail, + ) + + async def _find_paths( + self, + call_hierarchy: WithRequestCallHierarchy, + source: CallHierarchyItem, + target: CallHierarchyItem, + max_depth: int, + ) -> list[list[ChainNode]]: + """ + Find all paths from source to target using BFS. + + Returns a list of chains, where each chain is a list of ChainNodes + representing a path from source to target. + """ + target_key = self._item_key(target) + found_chains: list[list[ChainNode]] = [] + + # BFS queue: (current_item, path_so_far, depth) + queue: deque[tuple[CallHierarchyItem, list[ChainNode], int]] = deque() + queue.append((source, [self._to_chain_node(source)], 0)) + + # Track visited nodes to avoid cycles + visited: set[str] = set() + + while queue: + current_item, path, depth = queue.popleft() + current_key = self._item_key(current_item) + + # Check if we've reached the target + if current_key == target_key: + found_chains.append(path) + continue + + # Skip if we've exceeded max depth + if depth >= max_depth: + continue + + # Skip if already visited (cycle detection) + if current_key in visited: + continue + visited.add(current_key) + + # Get outgoing calls from current item + outgoing_calls = ( + await call_hierarchy._request_call_hierarchy_outgoing_calls( + CallHierarchyOutgoingCallsParams(item=current_item) + ) + ) + + if not outgoing_calls: + continue + + # Add each outgoing call to the queue + for call in outgoing_calls: + next_item = call.to + next_key = self._item_key(next_item) + # Skip if already visited to prevent redundant queue entries + if next_key not in visited: + next_node = self._to_chain_node(next_item) + next_path = [*path, next_node] + queue.append((next_item, next_path, depth + 1)) + + return found_chains + + def _item_key(self, item: CallHierarchyItem) -> str: + """Generate a unique key for a CallHierarchyItem""" + file_path = self.client.from_uri(item.uri, relative=False) + return f"{file_path}:{item.range.start.line}:{item.range.start.character}:{item.name}" diff --git a/src/lsap/schema/draft/explore.py b/src/lsap/schema/draft/explore.py new file mode 100644 index 0000000..ebd299b --- /dev/null +++ b/src/lsap/schema/draft/explore.py @@ -0,0 +1,182 @@ +""" +# Explore API + +The Explore API provides relationship-oriented code exploration, answering "what's around this symbol" +(siblings, dependencies, hierarchy) rather than just "what it is" (definition). + +## Example Usage + +### Scenario 1: Exploring siblings and dependencies of a class + +Request: + +```json +{ + "locate": { + "file_path": "src/models.py", + "scope": { + "symbol_path": ["User"] + } + }, + "include": ["siblings", "dependencies"], + "max_items": 10 +} +``` + +### Scenario 2: Exploring class hierarchy and calls + +Request: + +```json +{ + "locate": { + "file_path": "src/services.py", + "scope": { + "symbol_path": ["AuthService"] + } + }, + "include": ["hierarchy", "calls"], + "resolve_info": true +} +``` +""" + +from typing import Final, Literal + +from pydantic import BaseModel, ConfigDict, Field + +from .._abc import Response +from ..locate import LocateRequest +from ..models import CallHierarchy, SymbolInfo + + +class HierarchyInfo(BaseModel): + """Information about the inheritance hierarchy of a symbol.""" + + parents: list[SymbolInfo] = Field( + default_factory=list, description="Parent classes or interfaces" + ) + children: list[SymbolInfo] = Field( + default_factory=list, description="Child classes or implementations" + ) + + +class ExploreRequest(LocateRequest): + """ + Request to explore relationships around a symbol. + + Provides information about siblings, dependencies, hierarchy, and calls. + """ + + include: list[ + Literal["siblings", "dependencies", "dependents", "hierarchy", "calls"] + ] = Field(default=["siblings", "dependencies"]) + """Types of relationships to include in the exploration.""" + + max_items: int = Field(default=10, ge=1, le=50) + """Maximum number of items to return for each relationship type.""" + + resolve_info: bool = False + """Whether to resolve detailed information for symbols.""" + + include_external: bool = False + """Whether to include external dependencies if available.""" + + +markdown_template: Final = """ +# Explore: `{{ target.path | join: "." }}` (`{{ target.kind }}`) + +{% if siblings.size > 0 -%} +## Siblings +{% for item in siblings -%} +- `{{ item.name }}` (`{{ item.kind }}`) {% if item.range != nil %}at line {{ item.range.start.line | plus: 1 }}{% endif %} +{% endfor -%} +{%- endif %} + +{% if dependencies.size > 0 -%} +## Dependencies +{% for item in dependencies -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if dependents.size > 0 -%} +## Dependents +{% for item in dependents -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if hierarchy != nil -%} +## Hierarchy +{% if hierarchy.parents.size > 0 -%} +### Parents +{% for item in hierarchy.parents -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} + +{% if hierarchy.children.size > 0 -%} +### Children +{% for item in hierarchy.children -%} +- `{{ item.name }}` (`{{ item.kind }}`) in `{{ item.file_path }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +{% if calls != nil -%} +## Call Hierarchy +{% if calls.incoming.size > 0 -%} +### Incoming Calls +{% for item in calls.incoming -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line | plus: 1 }}` +{% endfor -%} +{%- endif %} + +{% if calls.outgoing.size > 0 -%} +### Outgoing Calls +{% for item in calls.outgoing -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line | plus: 1 }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +--- +> [!TIP] +> Use this map to understand the architectural context and impact of changes to this symbol. +""" + + +class ExploreResponse(Response): + """Response containing relationship-oriented information about a symbol.""" + + target: SymbolInfo + """The symbol being explored.""" + + siblings: list[SymbolInfo] = Field(default_factory=list) + """Symbols defined in the same scope or file.""" + + dependencies: list[SymbolInfo] = Field(default_factory=list) + """Symbols that this symbol depends on.""" + + dependents: list[SymbolInfo] = Field(default_factory=list) + """Symbols that depend on this symbol.""" + + hierarchy: HierarchyInfo | None = None + """Inheritance hierarchy information.""" + + calls: CallHierarchy | None = None + """Call hierarchy information.""" + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "ExploreRequest", + "ExploreResponse", + "HierarchyInfo", +] diff --git a/src/lsap/schema/draft/inspect.py b/src/lsap/schema/draft/inspect.py new file mode 100644 index 0000000..a53a161 --- /dev/null +++ b/src/lsap/schema/draft/inspect.py @@ -0,0 +1,166 @@ +""" +# Inspect API + +The Inspect API provides "how to use" information for a symbol, including usage examples, +signatures, and documentation. It is designed to help Agents understand how to correctly +invoke or interact with a symbol. + +## Example Usage + +### Scenario 1: Inspecting a function for usage examples + +Request: + +```json +{ + "locate": { + "file_path": "src/utils.py", + "scope": { + "symbol_path": ["format_date"] + } + }, + "include_examples": 5, + "include_signature": true +} +``` + +### Scenario 2: Inspecting a class with call hierarchy + +Request: + +```json +{ + "locate": { + "file_path": "src/models.py", + "scope": { + "symbol_path": ["User"] + } + }, + "include_call_hierarchy": true +} +``` +""" + +from typing import Final + +from pydantic import BaseModel, ConfigDict, Field + +from .._abc import Response +from ..locate import LocateRequest +from ..models import CallHierarchy, Location, SymbolDetailInfo, SymbolInfo + + +class UsageExample(BaseModel): + """A code snippet showing how a symbol is used in context.""" + + code: str = Field(..., description="Code snippet with context") + context: SymbolInfo | None = Field(None, description="Where this usage occurs") + location: Location = Field(..., description="Exact position of the usage") + call_pattern: str | None = Field( + None, description="Extracted pattern like 'func(arg1, arg2)'" + ) + + +class InspectRequest(LocateRequest): + """ + Request to inspect a symbol for usage-oriented information. + + Provides signatures, documentation, and real-world usage examples from the codebase. + """ + + include_examples: int = Field(default=3, ge=0, le=20) + """Number of usage examples to include.""" + + include_signature: bool = True + """Whether to include the symbol's signature.""" + + include_doc: bool = True + """Whether to include the symbol's documentation (hover).""" + + include_call_hierarchy: bool = False + """Whether to include call hierarchy information.""" + + include_external: bool = False + """Whether to include examples from external libraries if available.""" + + context_lines: int = Field(default=2, ge=0, le=10) + """Number of context lines to include around each example.""" + + +markdown_template: Final = """ +# Inspect: `{{ info.path | join: "." }}` (`{{ info.kind }}`) + +{% if signature != nil -%} +## Signature +```python +{{ signature }} +``` +{%- endif %} + +{% if info.hover != nil -%} +## Documentation +{{ info.hover }} +{%- endif %} + +{% if examples.size > 0 -%} +## Usage Examples +{% for example in examples -%} +### Example {{ forloop.index }} +{% if example.context != nil -%} +In `{{ example.context.path | join: "." }}` (`{{ example.context.kind }}`) at `{{ example.location.file_path }}:{{ example.location.range.start.line }}` +{%- else -%} +At `{{ example.location.file_path }}:{{ example.location.range.start.line }}` +{%- endif %} + +{% if example.call_pattern != nil -%} +Pattern: `{{ example.call_pattern }}` +{%- endif %} + +```{{ example.location.file_path.suffix | remove_first: "." }} +{{ example.code }} +``` +{% endfor -%} +{%- endif %} + +{% if call_hierarchy != nil -%} +{% if call_hierarchy.incoming.size > 0 -%} +## Incoming Calls +{% for item in call_hierarchy.incoming -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line }}` +{% endfor -%} +{%- endif %} + +{% if call_hierarchy.outgoing.size > 0 -%} +## Outgoing Calls +{% for item in call_hierarchy.outgoing -%} +- `{{ item.name }}` (`{{ item.kind }}`) at `{{ item.file_path }}:{{ item.range.start.line }}` +{% endfor -%} +{%- endif %} +{%- endif %} + +--- +> [!TIP] +> Use these examples to understand the expected arguments and common calling patterns for this symbol. +""" + + +class InspectResponse(Response): + """Response containing usage-oriented information about a symbol.""" + + info: SymbolDetailInfo + signature: str | None = None + examples: list[UsageExample] = Field(default_factory=list) + call_hierarchy: CallHierarchy | None = None + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "InspectRequest", + "InspectResponse", + "UsageExample", +] diff --git a/src/lsap/schema/relation.py b/src/lsap/schema/relation.py new file mode 100644 index 0000000..d9c2bf3 --- /dev/null +++ b/src/lsap/schema/relation.py @@ -0,0 +1,83 @@ +from pathlib import Path +from typing import Final + +from pydantic import BaseModel, ConfigDict + +from lsap.schema._abc import Request, Response +from lsap.schema.locate import Locate + + +class ChainNode(BaseModel): + """ + A node in a call chain. + """ + + name: str + kind: str + file_path: Path + detail: str | None = None + + +class RelationRequest(Request): + """ + Finds call chains connecting two symbols. + + Answers the question: "How does A reach B?" + + Use cases: + - Code flow understanding: How does handle_request reach db.query? + - Impact analysis: Which entry points are affected if I modify X? + - Architecture validation: Verify module A never calls module B + """ + + source: Locate + target: Locate + + max_depth: int = 10 + """Maximum depth to search for connections""" + + +markdown_template: Final = """ +# Relation: `{{ source.name }}` → `{{ target.name }}` + +{% if chains.size > 0 %} +Found {{ chains.size }} call chain(s): + +{% for chain in chains %} +### Chain {{ forloop.index }} +{% for node in chain %} +{{ forloop.index }}. **{{ node.name }}** (`{{ node.kind }}`) - `{{ node.file_path }}` +{%- if node.detail %} — {{ node.detail }}{% endif %} +{% endfor %} +{% endfor %} +{% else %} +No connection found between `{{ source.name }}` and `{{ target.name }}` within depth {{ request.max_depth }}. +{% endif %} +""" + + +class RelationResponse(Response): + request: RelationRequest + """The original request""" + + source: ChainNode + """The resolved source symbol""" + + target: ChainNode + """The resolved target symbol""" + + chains: list[list[ChainNode]] + """All paths found. Each path is a sequence of nodes from source to target.""" + + model_config = ConfigDict( + json_schema_extra={ + "markdown": markdown_template, + } + ) + + +__all__ = [ + "ChainNode", + "RelationRequest", + "RelationResponse", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/framework/__init__.py b/tests/framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/framework/lsp.py b/tests/framework/lsp.py new file mode 100644 index 0000000..1050228 --- /dev/null +++ b/tests/framework/lsp.py @@ -0,0 +1,595 @@ +from __future__ import annotations + +import tempfile +from collections.abc import AsyncGenerator, Sequence +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any + +import attrs +from lsp_client.capability.request.completion import WithRequestCompletion +from lsp_client.capability.request.definition import WithRequestDefinition +from lsp_client.capability.request.document_symbol import WithRequestDocumentSymbol +from lsp_client.capability.request.hover import WithRequestHover +from lsp_client.capability.request.reference import WithRequestReferences +from lsp_client.client.abc import Client +from lsp_client.utils.types import Position, Range, lsp_type + +type LspResponse[R] = R | None + + +@attrs.define +class LspInteraction[C: Client]: + client: C + workspace_root: Path + + @property + def resolved_workspace(self) -> Path: + """Get resolved workspace path for path comparisons.""" + return self.workspace_root.resolve() + + def full_path(self, relative_path: str) -> Path: + # Return the original path (not resolved) for file operations + # This is important for symlinked fixtures + return self.workspace_root / relative_path + + def full_path_resolved(self, relative_path: str) -> Path: + """Return resolved path for comparison with pyrefly responses.""" + return (self.workspace_root / relative_path).resolve() + + async def create_file(self, relative_path: str, content: str) -> Path: + path = self.full_path(relative_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + async def request_definition( + self, relative_path: str, line: int, column: int + ) -> DefinitionAssertion: + assert isinstance(self.client, WithRequestDefinition) + path = self.full_path(relative_path) + resp = await self.client.request_definition( + file_path=path, + position=Position(line=line, character=column), + ) + return DefinitionAssertion(self, resp) + + async def request_hover( + self, relative_path: str, line: int, column: int + ) -> HoverAssertion: + assert isinstance(self.client, WithRequestHover) + path = self.full_path(relative_path) + resp = await self.client.request_hover( + file_path=path, + position=Position(line=line, character=column), + ) + return HoverAssertion(self, resp) + + async def request_completion( + self, relative_path: str, line: int, column: int + ) -> CompletionAssertion: + assert isinstance(self.client, WithRequestCompletion) + path = self.full_path(relative_path) + resp = await self.client.request_completion( + file_path=path, + position=Position(line=line, character=column), + ) + return CompletionAssertion(self, resp) + + async def request_references( + self, relative_path: str, line: int, column: int + ) -> ReferencesAssertion: + assert isinstance(self.client, WithRequestReferences) + path = self.full_path(relative_path) + resp = await self.client.request_references( + file_path=path, + position=Position(line=line, character=column), + ) + return ReferencesAssertion(self, resp) + + async def request_document_symbols( + self, relative_path: str + ) -> DocumentSymbolsAssertion: + assert isinstance(self.client, WithRequestDocumentSymbol) + path = self.full_path(relative_path) + resp = await self.client.request_document_symbol(file_path=path) + return DocumentSymbolsAssertion(self, resp) + + +@attrs.define +class DefinitionAssertion: + interaction: LspInteraction[Any] + response: ( + lsp_type.Location + | Sequence[lsp_type.Location] + | Sequence[lsp_type.LocationLink] + | None + ) + + def expect_definition( + self, + relative_path: str, + start_line: int, + start_col: int, + end_line: int, + end_col: int, + ) -> None: + assert self.response is not None, "Definition response is None" + + # Use resolved path for comparison with pyrefly responses + expected_path = self.interaction.full_path_resolved(relative_path) + expected_range = Range( + start=Position(line=start_line, character=start_col), + end=Position(line=end_line, character=end_col), + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using resolved paths to handle symlinks properly + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using resolved paths to handle symlinks + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Pyrefly may return resolved paths + # Compare using resolved paths to handle symlinks properly + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using resolved paths to handle symlinks + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Pyrefly may return resolved paths, so we need to handle symlinks + # Try to resolve both paths and compare + try: + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + assert actual_resolved == expected_resolved, ( + f"Expected resolved path {expected_resolved}, got {actual_resolved}" + ) + except ValueError: + # Fallback to relative path comparison if resolve fails + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Pyrefly may return resolved paths, so we need to handle symlinks + try: + actual_resolved = actual_path.resolve() + expected_resolved = expected_path.resolve() + path_match = actual_resolved == expected_resolved + except ValueError: + # Fallback to relative path comparison + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + path_match = False + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + match self.response: + case lsp_type.Location() as loc: + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + assert actual_rel == expected_rel, ( + f"Expected path {expected_rel}, got {actual_rel}" + ) + except ValueError: + # Fallback to absolute path comparison + assert actual_path.resolve() == expected_path.resolve(), ( + f"Expected path {expected_path}, got {actual_path}" + ) + assert loc.range == expected_range + case list() | Sequence() as locs: + found = False + for loc in locs: + if isinstance(loc, lsp_type.Location): + actual_path = Path( + self.interaction.client.from_uri(loc.uri, relative=False) + ) + actual_range = loc.range + elif isinstance(loc, lsp_type.LocationLink): + actual_path = Path( + self.interaction.client.from_uri( + loc.target_uri, relative=False + ) + ) + actual_range = loc.target_selection_range + else: + continue + + # Compare using relative paths to handle symlinks + try: + actual_rel = actual_path.relative_to( + self.interaction.workspace_root + ) + expected_rel = expected_path.relative_to( + self.interaction.workspace_root + ) + path_match = actual_rel == expected_rel + except ValueError: + # Fallback to absolute path comparison + path_match = actual_path.resolve() == expected_path.resolve() + + if path_match and actual_range == expected_range: + found = True + break + + assert found, ( + f"Definition not found at {expected_path}:{expected_range}" + ) + case _: + raise TypeError( + f"Unexpected definition response type: {type(self.response)}" + ) + + +@attrs.define +class HoverAssertion: + interaction: LspInteraction[Any] + response: lsp_type.MarkupContent | None + + def expect_content(self, pattern: str) -> None: + assert self.response is not None, "Hover response is None" + assert pattern in self.response.value, ( + f"Expected '{pattern}' in hover content, got '{self.response.value}'" + ) + + +@attrs.define +class CompletionAssertion: + interaction: LspInteraction[Any] + response: Sequence[lsp_type.CompletionItem] + + def expect_label(self, label: str) -> None: + labels = [item.label for item in self.response] + assert label in labels, ( + f"Expected completion label '{label}' not found in {labels}" + ) + + +@attrs.define +class ReferencesAssertion: + interaction: LspInteraction[Any] + response: Sequence[lsp_type.Location] | None + + def expect_reference( + self, + relative_path: str, + start_line: int, + start_col: int, + end_line: int, + end_col: int, + ) -> None: + assert self.response is not None, "References response is None" + expected_path = self.interaction.full_path(relative_path) + expected_range = Range( + start=Position(line=start_line, character=start_col), + end=Position(line=end_line, character=end_col), + ) + + found = False + for loc in self.response: + actual_path = self.interaction.client.from_uri(loc.uri, relative=False) + if ( + Path(actual_path).resolve() == expected_path + and loc.range == expected_range + ): + found = True + break + assert found, f"Reference not found at {expected_path}:{expected_range}" + + +@attrs.define +class DocumentSymbolsAssertion: + interaction: LspInteraction[Any] + response: ( + Sequence[lsp_type.SymbolInformation] | Sequence[lsp_type.DocumentSymbol] | None + ) + _last_found_names: list[str] = attrs.field(factory=list, init=False) + + def expect_symbol(self, name: str, kind: lsp_type.SymbolKind | None = None) -> None: + assert self.response is not None, "Document symbols response is None" + + def check_symbols( + symbols: Sequence[lsp_type.SymbolInformation] + | Sequence[lsp_type.DocumentSymbol], + found_names: list[str], + ) -> bool: + for sym in symbols: + if isinstance(sym, lsp_type.DocumentSymbol): + found_names.append(f"{sym.name} ({sym.kind})") + if sym.name == name and (kind is None or sym.kind == kind): + return True + if sym.children and check_symbols(sym.children, found_names): + return True + elif isinstance(sym, lsp_type.SymbolInformation): + found_names.append(f"{sym.name} ({sym.kind})") + if sym.name == name and (kind is None or sym.kind == kind): + return True + return False + + self._last_found_names = [] + if not check_symbols(self.response, self._last_found_names): + print(f"Actually found: {self._last_found_names}") + raise AssertionError(f"Symbol '{name}' not found") + + +@asynccontextmanager +async def lsp_interaction_context[C: Client]( + client_cls: type[C], workspace_root: Path | None = None, **client_kwargs: Any +) -> AsyncGenerator[LspInteraction[C], None]: + if workspace_root is None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir).resolve() + async with client_cls(workspace=root, **client_kwargs) as client: + yield LspInteraction(client=client, workspace_root=root) + else: + # Use original path to preserve symlinks + # This is important for test fixtures that are symlinked + root = workspace_root + async with client_cls(workspace=root, **client_kwargs) as client: + yield LspInteraction(client=client, workspace_root=root) diff --git a/tests/test_explore_schema.py b/tests/test_explore_schema.py new file mode 100644 index 0000000..88a0363 --- /dev/null +++ b/tests/test_explore_schema.py @@ -0,0 +1,233 @@ +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from lsap.schema.draft.explore import ExploreRequest, ExploreResponse, HierarchyInfo +from lsap.schema.locate import Locate +from lsap.schema.models import ( + CallHierarchy, + CallHierarchyItem, + Position, + Range, + SymbolInfo, + SymbolKind, +) + + +def test_explore_request_defaults(): + """Test ExploreRequest default values.""" + req = ExploreRequest(locate=Locate(file_path=Path("test.py"), find="MyClass")) + assert req.include == ["siblings", "dependencies"] + assert req.max_items == 10 + assert req.resolve_info is False + assert req.include_external is False + + +def test_explore_request_validation(): + """Test ExploreRequest field validation.""" + # Test max_items range (1-50) + with pytest.raises(ValidationError): + ExploreRequest( + locate=Locate(file_path=Path("test.py"), find="MyClass"), max_items=0 + ) + with pytest.raises(ValidationError): + ExploreRequest( + locate=Locate(file_path=Path("test.py"), find="MyClass"), max_items=51 + ) + + +def test_hierarchy_info_model(): + """Test HierarchyInfo model with parents and children.""" + parent = SymbolInfo( + name="Base", + kind=SymbolKind.Class, + file_path=Path("base.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["Base"], + ) + child = SymbolInfo( + name="Sub", + kind=SymbolKind.Class, + file_path=Path("sub.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["Sub"], + ) + info = HierarchyInfo(parents=[parent], children=[child]) + assert len(info.parents) == 1 + assert len(info.children) == 1 + assert info.parents[0].name == "Base" + assert info.children[0].name == "Sub" + + +def test_explore_response_serialization(): + """Test ExploreResponse serialization and deserialization.""" + target = SymbolInfo( + name="MyClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["MyClass"], + ) + + sibling = SymbolInfo( + name="OtherClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=20, character=1), end=Position(line=30, character=1) + ), + path=["OtherClass"], + ) + + resp = ExploreResponse( + target=target, + siblings=[sibling], + dependencies=[], + dependents=[], + hierarchy=None, + calls=None, + ) + + # Serialize to dict + data = resp.model_dump() + assert data["target"]["name"] == "MyClass" + assert len(data["siblings"]) == 1 + assert data["siblings"][0]["name"] == "OtherClass" + + # Deserialize back + resp2 = ExploreResponse.model_validate(data) + assert resp2.target.name == "MyClass" + assert len(resp2.siblings) == 1 + assert resp2.siblings[0].name == "OtherClass" + + +def test_explore_response_markdown_rendering(): + """Test ExploreResponse.format('markdown') renders correctly.""" + target = SymbolInfo( + name="MyClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["MyClass"], + ) + + sibling = SymbolInfo( + name="OtherClass", + kind=SymbolKind.Class, + file_path=Path("test.py"), + range=Range( + start=Position(line=20, character=1), end=Position(line=30, character=1) + ), + path=["OtherClass"], + ) + + dependency = SymbolInfo( + name="Helper", + kind=SymbolKind.Class, + file_path=Path("utils.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["Helper"], + ) + + dependent = SymbolInfo( + name="Main", + kind=SymbolKind.Class, + file_path=Path("main.py"), + range=Range( + start=Position(line=10, character=1), end=Position(line=20, character=1) + ), + path=["Main"], + ) + + hierarchy = HierarchyInfo( + parents=[ + SymbolInfo( + name="BaseClass", + kind=SymbolKind.Class, + file_path=Path("base.py"), + range=Range( + start=Position(line=1, character=1), + end=Position(line=10, character=1), + ), + path=["BaseClass"], + ) + ], + children=[ + SymbolInfo( + name="SubClass", + kind=SymbolKind.Class, + file_path=Path("sub.py"), + range=Range( + start=Position(line=1, character=1), + end=Position(line=10, character=1), + ), + path=["SubClass"], + ) + ], + ) + + calls = CallHierarchy( + incoming=[ + CallHierarchyItem( + name="caller_func", + kind=SymbolKind.Function, + file_path=Path("caller.py"), + range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + ) + ], + outgoing=[ + CallHierarchyItem( + name="callee_func", + kind=SymbolKind.Function, + file_path=Path("callee.py"), + range=Range( + start=Position(line=10, character=1), + end=Position(line=10, character=10), + ), + ) + ], + ) + + resp = ExploreResponse( + target=target, + siblings=[sibling], + dependencies=[dependency], + dependents=[dependent], + hierarchy=hierarchy, + calls=calls, + ) + + markdown = resp.format("markdown") + + assert "# Explore: `MyClass` (`class`)" in markdown + assert "## Siblings" in markdown + assert "- `OtherClass` (`class`) at line 21" in markdown + assert "## Dependencies" in markdown + assert "- `Helper` (`class`) in `utils.py`" in markdown + assert "## Dependents" in markdown + assert "- `Main` (`class`) in `main.py`" in markdown + assert "## Hierarchy" in markdown + assert "### Parents" in markdown + assert "- `BaseClass` (`class`) in `base.py`" in markdown + assert "### Children" in markdown + assert "- `SubClass` (`class`) in `sub.py`" in markdown + assert "## Call Hierarchy" in markdown + assert "### Incoming Calls" in markdown + assert "- `caller_func` (`function`) at `caller.py:6`" in markdown + assert "### Outgoing Calls" in markdown + assert "- `callee_func` (`function`) at `callee.py:11`" in markdown + assert "Use this map to understand" in markdown diff --git a/tests/test_inspect_schema.py b/tests/test_inspect_schema.py new file mode 100644 index 0000000..2890895 --- /dev/null +++ b/tests/test_inspect_schema.py @@ -0,0 +1,191 @@ +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from lsap.schema.draft.inspect import InspectRequest, InspectResponse, UsageExample +from lsap.schema.locate import Locate +from lsap.schema.models import ( + CallHierarchy, + CallHierarchyItem, + Location, + Position, + Range, + SymbolDetailInfo, + SymbolInfo, + SymbolKind, +) + + +def test_inspect_request_defaults(): + """Test InspectRequest default values.""" + req = InspectRequest(locate=Locate(file_path=Path("test.py"), find="func")) + assert req.include_examples == 3 + assert req.include_signature is True + assert req.include_doc is True + assert req.include_call_hierarchy is False + assert req.include_external is False + assert req.context_lines == 2 + + +def test_inspect_request_validation(): + """Test InspectRequest field validation.""" + # Test include_examples range + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), include_examples=-1 + ) + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), include_examples=21 + ) + + # Test context_lines range + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), context_lines=-1 + ) + with pytest.raises(ValidationError): + InspectRequest( + locate=Locate(file_path=Path("test.py"), find="func"), context_lines=11 + ) + + +def test_usage_example_model(): + """Test UsageExample model instantiation.""" + example = UsageExample( + code="func('test')", + location=Location( + file_path=Path("main.py"), + range=Range( + start=Position(line=10, character=5), + end=Position(line=10, character=15), + ), + ), + context=SymbolInfo( + name="caller", + kind=SymbolKind.Function, + file_path=Path("main.py"), + range=Range( + start=Position(line=5, character=1), end=Position(line=15, character=1) + ), + path=["caller"], + ), + call_pattern="func(arg)", + ) + assert example.code == "func('test')" + assert example.call_pattern == "func(arg)" + assert example.context is not None + assert example.context.name == "caller" + + +def test_inspect_response_serialization(): + """Test InspectResponse serialization and deserialization.""" + info = SymbolDetailInfo( + name="my_func", + kind=SymbolKind.Function, + file_path=Path("test.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["my_func"], + hover="My function documentation", + ) + + example = UsageExample( + code="my_func()", + location=Location( + file_path=Path("app.py"), + range=Range( + start=Position(line=2, character=1), end=Position(line=2, character=9) + ), + ), + ) + + resp = InspectResponse( + info=info, signature="def my_func() -> None", examples=[example] + ) + + # Serialize to dict + data = resp.model_dump() + assert data["info"]["name"] == "my_func" + assert len(data["examples"]) == 1 + assert data["signature"] == "def my_func() -> None" + + # Deserialize back + resp2 = InspectResponse.model_validate(data) + assert resp2.info.name == "my_func" + assert resp2.signature == "def my_func() -> None" + assert len(resp2.examples) == 1 + + +def test_inspect_response_markdown_rendering(): + """Test InspectResponse.format('markdown') renders correctly.""" + info = SymbolDetailInfo( + name="my_func", + kind=SymbolKind.Function, + file_path=Path("test.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=5, character=1) + ), + path=["my_func"], + hover="My function documentation", + ) + + example = UsageExample( + code="my_func()", + location=Location( + file_path=Path("app.py"), + range=Range( + start=Position(line=2, character=1), end=Position(line=2, character=9) + ), + ), + context=SymbolInfo( + name="main", + kind=SymbolKind.Function, + file_path=Path("app.py"), + range=Range( + start=Position(line=1, character=1), end=Position(line=10, character=1) + ), + path=["main"], + ), + call_pattern="my_func()", + ) + + call_hierarchy = CallHierarchy( + incoming=[ + CallHierarchyItem( + name="caller_func", + kind=SymbolKind.Function, + file_path=Path("caller.py"), + range=Range( + start=Position(line=5, character=1), + end=Position(line=5, character=10), + ), + ) + ], + outgoing=[], + ) + + resp = InspectResponse( + info=info, + signature="def my_func() -> None", + examples=[example], + call_hierarchy=call_hierarchy, + ) + + markdown = resp.format("markdown") + + assert "# Inspect: `my_func` (`function`)" in markdown + assert "## Signature" in markdown + assert "def my_func() -> None" in markdown + assert "## Documentation" in markdown + assert "My function documentation" in markdown + assert "## Usage Examples" in markdown + assert "### Example 1" in markdown + assert "In `main` (`function`)" in markdown + assert "Pattern: `my_func()`" in markdown + assert "my_func()" in markdown + assert "## Incoming Calls" in markdown + assert "- `caller_func` (`function`)" in markdown + assert "Use these examples to understand" in markdown diff --git a/tests/test_relation.py b/tests/test_relation.py new file mode 100644 index 0000000..2995a4e --- /dev/null +++ b/tests/test_relation.py @@ -0,0 +1,824 @@ +""" +Functional tests for Relation API. + +Tests the call chain discovery capability that answers "how does A reach B?" +""" + +from contextlib import asynccontextmanager +from pathlib import Path + +import pytest +from lsp_client.capability.request import ( + WithRequestCallHierarchy, + WithRequestDocumentSymbol, +) +from lsp_client.client.document_state import DocumentStateManager +from lsp_client.protocol import CapabilityClientProtocol +from lsp_client.protocol.lang import LanguageConfig +from lsp_client.utils.config import ConfigurationMap +from lsp_client.utils.types import AnyPath +from lsp_client.utils.types import Position as LSPPosition +from lsp_client.utils.workspace import ( + DEFAULT_WORKSPACE_DIR, + Workspace, + WorkspaceFolder, +) +from lsprotocol.types import ( + CallHierarchyItem, + CallHierarchyOutgoingCall, + CallHierarchyOutgoingCallsParams, + DocumentSymbol, + LanguageKind, + SymbolKind, +) +from lsprotocol.types import Range as LSPRange + +from lsap.capability.relation import RelationCapability +from lsap.schema.locate import Locate, SymbolScope +from lsap.schema.relation import ChainNode, RelationRequest, RelationResponse + + +def test_chain_node_creation(): + """Test creating a ChainNode.""" + node = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + assert node.name == "handle_request" + assert node.kind == "Function" + assert node.file_path == Path("src/controllers.py") + assert node.detail is None + + +def test_chain_node_with_detail(): + """Test creating a ChainNode with detail.""" + node = ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + detail="(user_id: int) -> User", + ) + assert node.name == "UserService.get_user" + assert node.kind == "Method" + assert node.detail == "(user_id: int) -> User" + + +def test_relation_request_creation(): + """Test creating a RelationRequest.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + assert req.source.file_path == Path("src/controllers.py") + assert req.target.file_path == Path("src/db.py") + assert req.max_depth == 10 # default + + +def test_relation_request_with_custom_max_depth(): + """Test creating a RelationRequest with custom max_depth.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + max_depth=5, + ) + assert req.max_depth == 5 + + +def test_relation_response_with_single_chain(): + """Test RelationResponse with a single call chain.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + chain = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 + assert resp.chains[0][0].name == "handle_request" + assert resp.chains[0][1].name == "UserService.get_user" + assert resp.chains[0][2].name == "query" + + +def test_relation_response_with_multiple_chains(): + """Test RelationResponse with multiple call chains (example from docs).""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + # Chain 1: handle_request -> UserService.get_user -> db.query + chain1 = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + ), + target, + ] + + # Chain 2: handle_request -> AuthService.validate_token -> SessionManager.get_session -> db.query + chain2 = [ + source, + ChainNode( + name="AuthService.validate_token", + kind="Method", + file_path=Path("src/services/auth.py"), + ), + ChainNode( + name="SessionManager.get_session", + kind="Method", + file_path=Path("src/services/session.py"), + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + max_depth=5, + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain1, chain2], + ) + + assert len(resp.chains) == 2 + assert len(resp.chains[0]) == 3 + assert len(resp.chains[1]) == 4 + assert resp.chains[0][1].name == "UserService.get_user" + assert resp.chains[1][1].name == "AuthService.validate_token" + assert resp.chains[1][2].name == "SessionManager.get_session" + + +def test_relation_response_with_no_chains(): + """Test RelationResponse when no connection found.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="unrelated_function", + kind="Function", + file_path=Path("src/utils.py"), + ) + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/utils.py"), + scope=SymbolScope(symbol_path=["unrelated_function"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[], + ) + + assert len(resp.chains) == 0 + + +def test_relation_response_markdown_format_with_chains(): + """Test markdown formatting of RelationResponse with chains.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="query", + kind="Function", + file_path=Path("src/db.py"), + ) + + chain = [ + source, + ChainNode( + name="UserService.get_user", + kind="Method", + file_path=Path("src/services/user.py"), + detail="(user_id: int) -> User", + ), + target, + ] + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["query"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + markdown = resp.format() + + # Check that markdown contains key elements + assert "handle_request" in markdown + assert "query" in markdown + assert "Found 1 call chain(s)" in markdown + assert "Chain 1" in markdown + assert "UserService.get_user" in markdown + assert "Method" in markdown + assert "(user_id: int) -> User" in markdown + + +def test_relation_response_markdown_format_no_chains(): + """Test markdown formatting of RelationResponse with no chains.""" + source = ChainNode( + name="handle_request", + kind="Function", + file_path=Path("src/controllers.py"), + ) + target = ChainNode( + name="unrelated_function", + kind="Function", + file_path=Path("src/utils.py"), + ) + + req = RelationRequest( + source=Locate( + file_path=Path("src/controllers.py"), + scope=SymbolScope(symbol_path=["handle_request"]), + ), + target=Locate( + file_path=Path("src/utils.py"), + scope=SymbolScope(symbol_path=["unrelated_function"]), + ), + max_depth=5, + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[], + ) + + markdown = resp.format() + + # Check that markdown indicates no connection + assert "handle_request" in markdown + assert "unrelated_function" in markdown + assert "No connection found" in markdown + assert "within depth 5" in markdown + + +def test_chain_node_equality(): + """Test that ChainNodes with same data are equal.""" + node1 = ChainNode( + name="foo", + kind="Function", + file_path=Path("test.py"), + detail="detail", + ) + node2 = ChainNode( + name="foo", + kind="Function", + file_path=Path("test.py"), + detail="detail", + ) + assert node1 == node2 + + +def test_relation_response_chain_order(): + """Test that chains preserve order.""" + source = ChainNode( + name="a", + kind="Function", + file_path=Path("a.py"), + ) + middle1 = ChainNode( + name="b", + kind="Function", + file_path=Path("b.py"), + ) + middle2 = ChainNode( + name="c", + kind="Function", + file_path=Path("c.py"), + ) + target = ChainNode( + name="d", + kind="Function", + file_path=Path("d.py"), + ) + + chain = [source, middle1, middle2, target] + + req = RelationRequest( + source=Locate( + file_path=Path("a.py"), + scope=SymbolScope(symbol_path=["a"]), + ), + target=Locate( + file_path=Path("d.py"), + scope=SymbolScope(symbol_path=["d"]), + ), + ) + + resp = RelationResponse( + request=req, + source=source, + target=target, + chains=[chain], + ) + + # Verify order is preserved + assert resp.chains[0][0].name == "a" + assert resp.chains[0][1].name == "b" + assert resp.chains[0][2].name == "c" + assert resp.chains[0][3].name == "d" + + +def test_relation_request_with_nested_symbol_path(): + """Test RelationRequest with nested symbol paths.""" + req = RelationRequest( + source=Locate( + file_path=Path("src/services/user.py"), + scope=SymbolScope(symbol_path=["UserService", "get_user"]), + ), + target=Locate( + file_path=Path("src/db.py"), + scope=SymbolScope(symbol_path=["Database", "query"]), + ), + max_depth=3, + ) + + assert req.source.scope is not None + assert isinstance(req.source.scope, SymbolScope) + assert req.source.scope.symbol_path == ["UserService", "get_user"] + assert req.target.scope is not None + assert isinstance(req.target.scope, SymbolScope) + assert req.target.scope.symbol_path == ["Database", "query"] + assert req.max_depth == 3 + + +# ============================================================================ +# Integration tests with mock LSP client +# ============================================================================ + + +class MockRelationClient( + WithRequestCallHierarchy, WithRequestDocumentSymbol, CapabilityClientProtocol +): + """Mock client for testing RelationCapability with call hierarchy support.""" + + def __init__(self, call_graph: dict[str, list[str]] | None = None): + """ + Initialize mock client with a call graph. + + call_graph: Dictionary mapping symbol names to list of symbols they call. + Example: {"A": ["B", "C"], "B": ["D"]} means A calls B and C, B calls D. + """ + self.call_graph = call_graph or {} + self._workspace = Workspace( + { + DEFAULT_WORKSPACE_DIR: WorkspaceFolder( + uri=Path.cwd().as_uri(), + name=DEFAULT_WORKSPACE_DIR, + ) + } + ) + self._config_map = ConfigurationMap() + self._doc_state = DocumentStateManager() + + def from_uri(self, uri: str, *, relative: bool = True) -> Path: + return Path(uri.replace("file://", "")) + + def get_workspace(self) -> Workspace: + return self._workspace + + def get_config_map(self) -> ConfigurationMap: + return self._config_map + + def get_document_state(self) -> DocumentStateManager: + return self._doc_state + + @classmethod + def get_language_config(cls): + return LanguageConfig( + kind=LanguageKind.Python, + suffixes=["py"], + project_files=["pyproject.toml"], + ) + + async def request(self, req, schema): + return None + + async def notify(self, msg): + pass + + async def read_file(self, file_path) -> str: + return "# Mock file content" + + async def write_file(self, uri: str, content: str) -> None: + pass + + @asynccontextmanager + async def open_files(self, *file_paths): + yield + + async def request_document_symbol_list( + self, file_path: AnyPath + ) -> list[DocumentSymbol]: + """Mock document symbol list - returns a single function symbol based on file name.""" + path = Path(file_path) + # Extract symbol name from file path (e.g., test_A.py -> A) + name = path.stem.replace("test_", "") + if name in self.call_graph or any( + name in calls for calls in self.call_graph.values() + ): + return [ + DocumentSymbol( + name=name, + kind=SymbolKind.Function, + range=LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=1, character=0), + ), + selection_range=LSPRange( + start=LSPPosition(line=0, character=4), + end=LSPPosition(line=0, character=4 + len(name)), + ), + ) + ] + return [] + + async def request_document_symbol_information_list(self, file_path): + return [] + + def _make_call_hierarchy_item(self, name: str) -> CallHierarchyItem: + """Create a mock CallHierarchyItem for a symbol name.""" + return CallHierarchyItem( + name=name, + kind=SymbolKind.Function, + uri=f"file://test_{name}.py", + range=LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=1, character=0), + ), + selection_range=LSPRange( + start=LSPPosition(line=0, character=4), + end=LSPPosition(line=0, character=4 + len(name)), + ), + ) + + async def prepare_call_hierarchy( + self, file_path: AnyPath, position: LSPPosition + ) -> list[CallHierarchyItem] | None: + """Mock prepare_call_hierarchy - returns item based on file path.""" + path = Path(file_path) + # Extract symbol name from file path (e.g., test_A.py -> A) + name = path.stem.replace("test_", "") + if name in self.call_graph or any( + name in calls for calls in self.call_graph.values() + ): + return [self._make_call_hierarchy_item(name)] + return None + + async def _request_call_hierarchy_outgoing_calls( + self, params: CallHierarchyOutgoingCallsParams + ) -> list[CallHierarchyOutgoingCall] | None: + """Mock outgoing calls based on the call graph.""" + item_name = params.item.name + if item_name not in self.call_graph: + return [] + + outgoing = [] + for target_name in self.call_graph[item_name]: + target_item = self._make_call_hierarchy_item(target_name) + outgoing.append( + CallHierarchyOutgoingCall( + to=target_item, + from_ranges=[ + LSPRange( + start=LSPPosition(line=0, character=0), + end=LSPPosition(line=0, character=1), + ) + ], + ) + ) + return outgoing + + +@pytest.mark.asyncio +async def test_relation_capability_single_path(): + """Test finding a single path between two symbols.""" + # Call graph: A -> B -> C + call_graph = {"A": ["B"], "B": ["C"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "C" + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # A -> B -> C + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + assert resp.chains[0][2].name == "C" + + +@pytest.mark.asyncio +async def test_relation_capability_multiple_paths(): + """Test finding multiple paths between two symbols.""" + # Call graph: A -> B -> D, A -> C -> D (two paths from A to D) + call_graph = {"A": ["B", "C"], "B": ["D"], "C": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 2 # Two paths: A->B->D and A->C->D + + # Both chains should start with A and end with D + for chain in resp.chains: + assert chain[0].name == "A" + assert chain[-1].name == "D" + assert len(chain) == 3 # A -> (B or C) -> D + + # Check that we have both paths + middle_nodes = {chain[1].name for chain in resp.chains} + assert middle_nodes == {"B", "C"} + + +@pytest.mark.asyncio +async def test_relation_capability_no_path(): + """Test when no path exists between source and target.""" + # Call graph: A -> B, C -> D (no connection from A to D) + call_graph = {"A": ["B"], "C": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 0 + + +@pytest.mark.asyncio +async def test_relation_capability_max_depth(): + """Test that max_depth is respected.""" + # Call graph: A -> B -> C -> D -> E (chain of 4 calls) + call_graph = {"A": ["B"], "B": ["C"], "C": ["D"], "D": ["E"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + # With max_depth=3, we can reach D (3 hops: A->B, B->C, C->D) + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=3, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # A -> B -> C + + # With max_depth=2, we can reach C (2 hops: A->B, B->C) + req2 = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=2, + ) + + resp2 = await capability(req2) + assert resp2 is not None + assert len(resp2.chains) == 1 + + # With max_depth=1, we cannot reach C (only B is reachable) + req3 = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_C.py"), + scope=SymbolScope(symbol_path=["C"]), + ), + max_depth=1, + ) + + resp3 = await capability(req3) + assert resp3 is not None + assert len(resp3.chains) == 0 # Cannot reach C within max_depth=1 + + +@pytest.mark.asyncio +async def test_relation_capability_cycle_detection(): + """Test that cycles are properly detected and don't cause infinite loops.""" + # Call graph with cycle: A -> B -> C -> B (cycle between B and C) + # But there's also A -> B -> D (path without cycle) + call_graph = {"A": ["B"], "B": ["C", "D"], "C": ["B"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=10, + ) + + resp = await capability(req) + assert resp is not None + assert resp.source.name == "A" + assert resp.target.name == "D" + assert len(resp.chains) == 1 # Should find A -> B -> D + assert len(resp.chains[0]) == 3 + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + assert resp.chains[0][2].name == "D" + + +@pytest.mark.asyncio +async def test_relation_capability_direct_call(): + """Test direct call (source directly calls target).""" + # Call graph: A -> B (direct call) + call_graph = {"A": ["B"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_B.py"), + scope=SymbolScope(symbol_path=["B"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 2 # Just A -> B + assert resp.chains[0][0].name == "A" + assert resp.chains[0][1].name == "B" + + +@pytest.mark.asyncio +async def test_relation_capability_different_path_lengths(): + """Test finding paths of different lengths.""" + # Call graph: A -> D (direct), A -> B -> D (2 hops), A -> C -> E -> D (3 hops) + call_graph = {"A": ["B", "C", "D"], "B": ["D"], "C": ["E"], "E": ["D"]} + client = MockRelationClient(call_graph) + capability = RelationCapability(client=client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=Path("test_A.py"), + scope=SymbolScope(symbol_path=["A"]), + ), + target=Locate( + file_path=Path("test_D.py"), + scope=SymbolScope(symbol_path=["D"]), + ), + max_depth=5, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 3 # Three paths of different lengths + + # Check path lengths + path_lengths = sorted([len(chain) for chain in resp.chains]) + assert path_lengths == [2, 3, 4] # Direct, 2-hop, and 3-hop paths diff --git a/tests/test_relation_e2e.py b/tests/test_relation_e2e.py new file mode 100644 index 0000000..8636a00 --- /dev/null +++ b/tests/test_relation_e2e.py @@ -0,0 +1,535 @@ +""" +End-to-end tests for RelationCapability using BasedpyrightClient. + +These tests verify the RelationCapability's ability to find call chains +between functions in real Python code using a real Language Server. +""" + +from __future__ import annotations + +import anyio +import pytest +from lsp_client.clients.basedpyright import BasedpyrightClient + +from lsap.capability.relation import RelationCapability +from lsap.schema.locate import Locate, SymbolScope +from lsap.schema.relation import RelationRequest + +from .framework.lsp import lsp_interaction_context + + +@pytest.mark.e2e +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_direct_call(): + """Test finding a direct call relationship: A -> B.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + # Create file with direct call + content = ''' +def caller(): + """The calling function.""" + return callee() + +def callee(): + """The called function.""" + return 42 +''' + file_path = await interaction.create_file("direct_call.py", content) + + # Wait for LSP indexing + await anyio.sleep(1) + + # Create RelationCapability + capability = RelationCapability(client=interaction.client) # type: ignore + + # Request to find path from caller to callee + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["caller"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["callee"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Verify response + assert resp is not None, "Response should not be None" + assert resp.source.name == "caller" + assert resp.target.name == "callee" + assert len(resp.chains) == 1, f"Expected 1 chain, got {len(resp.chains)}" + assert len(resp.chains[0]) == 2, "Chain should have 2 nodes: caller -> callee" + assert resp.chains[0][0].name == "caller" + assert resp.chains[0][1].name == "callee" + + # Test markdown formatting + markdown = resp.format() + assert "caller" in markdown + assert "callee" in markdown + assert "Found 1 call chain(s)" in markdown + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_indirect_call(): + """Test finding an indirect call relationship: A -> B -> C.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def entry_point(): + """Entry point function.""" + return middle_layer() + +def middle_layer(): + """Middle layer function.""" + return final_destination() + +def final_destination(): + """Final destination function.""" + return "success" +''' + file_path = await interaction.create_file("indirect_call.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["entry_point"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["final_destination"]), + ), + max_depth=10, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "entry_point" + assert resp.target.name == "final_destination" + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3, "Chain should have 3 nodes" + assert resp.chains[0][0].name == "entry_point" + assert resp.chains[0][1].name == "middle_layer" + assert resp.chains[0][2].name == "final_destination" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_multiple_paths(): + """Test finding multiple paths between two functions.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def source(): + """Source function with multiple call paths.""" + path_a() + path_b() + return "done" + +def path_a(): + """First path to target.""" + return target() + +def path_b(): + """Second path to target.""" + return target() + +def target(): + """Target function.""" + return 100 +''' + file_path = await interaction.create_file("multiple_paths.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["source"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["target"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "source" + assert resp.target.name == "target" + assert len(resp.chains) == 2, f"Expected 2 paths, got {len(resp.chains)}" + + # Verify both paths exist + middle_nodes = {chain[1].name for chain in resp.chains} + assert middle_nodes == {"path_a", "path_b"} + + # Verify all chains have correct structure + for chain in resp.chains: + assert len(chain) == 3 + assert chain[0].name == "source" + assert chain[2].name == "target" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_no_connection(): + """Test when there's no connection between functions.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def isolated_a(): + """Isolated function A.""" + return 1 + +def isolated_b(): + """Isolated function B.""" + return 2 +''' + file_path = await interaction.create_file("no_connection.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["isolated_a"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["isolated_b"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "isolated_a" + assert resp.target.name == "isolated_b" + assert len(resp.chains) == 0, "Should find no chains" + + # Verify markdown shows no connection + markdown = resp.format() + assert "No connection found" in markdown + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_max_depth_limit(): + """Test that max_depth is respected.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def level_0(): + """Level 0.""" + return level_1() + +def level_1(): + """Level 1.""" + return level_2() + +def level_2(): + """Level 2.""" + return level_3() + +def level_3(): + """Level 3.""" + return "deep" +''' + file_path = await interaction.create_file("max_depth.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # With max_depth=2, can reach level_2 + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_0"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_2"]), + ), + max_depth=2, + ) + + resp = await capability(req) + assert resp is not None + assert len(resp.chains) == 1 + assert len(resp.chains[0]) == 3 # level_0 -> level_1 -> level_2 + + # With max_depth=1, cannot reach level_2 + req_shallow = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_0"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["level_2"]), + ), + max_depth=1, + ) + + resp_shallow = await capability(req_shallow) + assert resp_shallow is not None + assert len(resp_shallow.chains) == 0, "Should not find path within max_depth=1" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_with_class_methods(): + """Test finding call chains involving class methods.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +class Calculator: + """A simple calculator class.""" + + def add(self, a: int, b: int) -> int: + """Add two numbers.""" + return self._internal_add(a, b) + + def _internal_add(self, a: int, b: int) -> int: + """Internal addition method.""" + return a + b + +def use_calculator(): + """Use the calculator.""" + calc = Calculator() + return calc.add(1, 2) +''' + file_path = await interaction.create_file("class_methods.py", content) + + await anyio.sleep(1.5) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # Find path from use_calculator to _internal_add + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["use_calculator"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["Calculator", "_internal_add"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "use_calculator" + assert resp.target.name == "_internal_add" + # Should find path: use_calculator -> Calculator.add -> Calculator._internal_add + if len(resp.chains) > 0: + assert resp.chains[0][0].name == "use_calculator" + assert resp.chains[0][-1].name == "_internal_add" + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_cross_file(): + """Test finding call chains across multiple files.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + # Create module with utility function + util_content = ''' +def utility_function(): + """A utility function.""" + return "utility" +''' + util_path = await interaction.create_file("utils.py", util_content) + + # Create main file that imports and calls utility + main_content = ''' +from utils import utility_function + +def main(): + """Main function.""" + return utility_function() +''' + main_path = await interaction.create_file("main.py", main_content) + + await anyio.sleep(2) + + capability = RelationCapability(client=interaction.client) # type: ignore + + # Find path from main to utility_function + req = RelationRequest( + source=Locate( + file_path=main_path, + scope=SymbolScope(symbol_path=["main"]), + ), + target=Locate( + file_path=util_path, + scope=SymbolScope(symbol_path=["utility_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "main" + assert resp.target.name == "utility_function" + # Should find direct call: main -> utility_function + if len(resp.chains) > 0: + assert len(resp.chains[0]) == 2 + assert resp.chains[0][0].name == "main" + assert resp.chains[0][1].name == "utility_function" + # Verify different files + assert resp.chains[0][0].file_path != resp.chains[0][1].file_path + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_invalid_source(): + """Test when source symbol cannot be found.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def valid_function(): + """A valid function.""" + return 42 +''' + file_path = await interaction.create_file("invalid_source.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["nonexistent_function"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["valid_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Should return None when source cannot be located + assert resp is None + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_invalid_target(): + """Test when target symbol cannot be found.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def valid_function(): + """A valid function.""" + return 42 +''' + file_path = await interaction.create_file("invalid_target.py", content) + + await anyio.sleep(1) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["valid_function"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["nonexistent_function"]), + ), + max_depth=5, + ) + + resp = await capability(req) + + # Should return None when target cannot be located + assert resp is None + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_relation_different_path_lengths(): + """Test finding paths of different lengths to the same target.""" + async with lsp_interaction_context(BasedpyrightClient) as interaction: # type: ignore + content = ''' +def source(): + """Source with multiple paths of different lengths.""" + direct() + path1() + path2() + +def direct(): + """Direct path to target.""" + return target() + +def path1(): + """One hop to target.""" + return path1_helper() + +def path1_helper(): + """Helper for path1.""" + return target() + +def path2(): + """Two hops to target.""" + return path2_a() + +def path2_a(): + """First hop in path2.""" + return path2_b() + +def path2_b(): + """Second hop in path2.""" + return target() + +def target(): + """The target function.""" + return 100 +''' + file_path = await interaction.create_file("different_lengths.py", content) + + await anyio.sleep(1.5) + + capability = RelationCapability(client=interaction.client) # type: ignore + + req = RelationRequest( + source=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["source"]), + ), + target=Locate( + file_path=file_path, + scope=SymbolScope(symbol_path=["target"]), + ), + max_depth=10, + ) + + resp = await capability(req) + + assert resp is not None + assert resp.source.name == "source" + assert resp.target.name == "target" + # Should find multiple paths + assert len(resp.chains) >= 1, "Should find at least one path" + + # Verify we have paths of different lengths + path_lengths = sorted([len(chain) for chain in resp.chains]) + # Should have paths like: source->direct->target (3), source->path1->path1_helper->target (4), etc. + assert len(set(path_lengths)) > 1, "Should have paths of different lengths" diff --git a/tests/test_rename_glob.py b/tests/test_rename_glob.py index ceeda0d..abb5786 100644 --- a/tests/test_rename_glob.py +++ b/tests/test_rename_glob.py @@ -15,12 +15,13 @@ from lsprotocol.types import ( Range as LSPRange, ) -from test_rename_e2e import E2ERenameClient from lsap.capability.rename import RenameExecuteCapability, RenamePreviewCapability from lsap.schema.locate import Locate from lsap.schema.rename import RenameExecuteRequest, RenamePreviewRequest +from .test_rename_e2e import E2ERenameClient + class GlobPatternRenameClient(E2ERenameClient): """Client that returns edits for multiple files with different paths."""