Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/shared.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- name: Run pytest with coverage
shell: bash
run: |
uv run --frozen --no-sync coverage run -m pytest
uv run --frozen --no-sync coverage run -m pytest -n auto
uv run --frozen --no-sync coverage combine
uv run --frozen --no-sync coverage report
Expand Down
3 changes: 3 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This document contains critical information about working with this codebase. Fo
- Functions must be focused and small
- Follow existing patterns exactly
- Line length: 120 chars maximum
- FORBIDDEN: imports inside functions

3. Testing Requirements
- Framework: `uv run --frozen pytest`
Expand All @@ -25,6 +26,8 @@ This document contains critical information about working with this codebase. Fo
- Coverage: test edge cases and errors
- New features require tests
- Bug fixes require regression tests
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.

- For commits fixing bugs or adding features based on user reports add:

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ xfail_strict = true
addopts = """
--color=yes
--capture=fd
--numprocesses auto
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually forgot to remove this when I added the scripts/test script.

You need to be able to run pytest without -n auto, because you may want to run a single test, and you don't need multiple cores for it.

"""
filterwarnings = [
"error",
Expand Down
7 changes: 1 addition & 6 deletions src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@ class InMemoryTransport:
result = await client.call_tool("my_tool", {...})
"""

def __init__(
self,
server: Server[Any] | FastMCP,
*,
raise_exceptions: bool = False,
) -> None:
def __init__(self, server: Server[Any] | FastMCP, *, raise_exceptions: bool = False) -> None:
"""Initialize the in-memory transport.
Args:
Expand Down
61 changes: 19 additions & 42 deletions src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
self,
server: Server[Any] | FastMCP,
*,
# TODO(Marcelo): When do `raise_exceptions=True` actually raises?
raise_exceptions: bool = False,
Comment on lines +61 to 62
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't test this. What it is supposed to raise?

read_timeout_seconds: float | None = None,
sampling_callback: SamplingFnT | None = None,
Expand Down Expand Up @@ -125,14 +126,9 @@ async def __aenter__(self) -> Client:
self._exit_stack = exit_stack.pop_all()
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
"""Exit the async context manager."""
if self._exit_stack:
if self._exit_stack: # pragma: no branch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why a new pragma?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I don't think it's necessary to test this. If you can suggest a cute test, please push it.

await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
self._session = None

Expand Down Expand Up @@ -177,28 +173,22 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul
"""Set the logging level on the server."""
return await self.session.set_logging_level(level)

async def list_resources(
self,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourcesResult:
async def list_resources(self, *, cursor: str | None = None) -> types.ListResourcesResult:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR started from this. I understand that in the typescript world this is the best practice, but it's not in Python.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will need to update migration docs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Client is a new class, it doesn't need any migration docs.

"""List available resources from the server."""
return await self.session.list_resources(params=params)
return await self.session.list_resources(params=types.PaginatedRequestParams(cursor=cursor))

async def list_resource_templates(
self,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
async def list_resource_templates(self, *, cursor: str | None = None) -> types.ListResourceTemplatesResult:
"""List available resource templates from the server."""
return await self.session.list_resource_templates(params=params)
return await self.session.list_resource_templates(params=types.PaginatedRequestParams(cursor=cursor))

async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult:
"""Read a resource from the server.

Args:
uri: The URI of the resource to read
uri: The URI of the resource to read.

Returns:
The resource content
The resource content.
"""
return await self.session.read_resource(uri)

Expand Down Expand Up @@ -239,26 +229,19 @@ async def call_tool(
meta=meta,
)

async def list_prompts(
self,
params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
async def list_prompts(self, *, cursor: str | None = None) -> types.ListPromptsResult:
"""List available prompts from the server."""
return await self.session.list_prompts(params=params)
return await self.session.list_prompts(params=types.PaginatedRequestParams(cursor=cursor))

async def get_prompt(
self,
name: str,
arguments: dict[str, str] | None = None,
) -> types.GetPromptResult:
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Get a prompt from the server.

Args:
name: The name of the prompt
arguments: Arguments to pass to the prompt

Returns:
The prompt content
The prompt content.
"""
return await self.session.get_prompt(name=name, arguments=arguments)

Expand All @@ -276,21 +259,15 @@ async def complete(
context_arguments: Additional context arguments

Returns:
Completion suggestions
Completion suggestions.
"""
return await self.session.complete(
ref=ref,
argument=argument,
context_arguments=context_arguments,
)
return await self.session.complete(ref=ref, argument=argument, context_arguments=context_arguments)

async def list_tools(
self,
params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
async def list_tools(self, *, cursor: str | None = None) -> types.ListToolsResult:
"""List available tools from the server."""
return await self.session.list_tools(params=params)
return await self.session.list_tools(params=types.PaginatedRequestParams(cursor=cursor))

async def send_roots_list_changed(self) -> None:
"""Send a notification that the roots list has changed."""
await self.session.send_roots_list_changed()
# TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support.
await self.session.send_roots_list_changed() # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is there a new pragma?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the test was not testing anything, and the server actually doesn't do anything with this notification. We need to solve the TODO, which will result in dropping the pragma.

5 changes: 1 addition & 4 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,7 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No
Args:
params: Full pagination parameters including cursor and any future fields
"""
return await self.send_request(
types.ListPromptsRequest(params=params),
types.ListPromptsResult,
)
return await self.send_request(types.ListPromptsRequest(params=params), types.ListPromptsResult)

async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
Expand Down
40 changes: 12 additions & 28 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ async def send_request(
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
"""Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
will take precedence over the session read timeout.
"""Sends a request and wait for a response.

Do not use this method to emit notifications! Use send_notification()
instead.
Raises an McpError if the response contains an error. If a request read timeout is provided, it will take
precedence over the session read timeout.

Do not use this method to emit notifications! Use send_notification() instead.
"""
request_id = self._request_id
self._request_id = request_id + 1
Expand All @@ -261,15 +261,10 @@ async def send_request(

try:
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)

await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))

# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None: # pragma: no cover
timeout = request_read_timeout_seconds
elif self._session_read_timeout_seconds is not None: # pragma: no cover
timeout = self._session_read_timeout_seconds
timeout = request_read_timeout_seconds or self._session_read_timeout_seconds
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simplified this.


try:
with anyio.fail_after(timeout):
Expand All @@ -279,9 +274,8 @@ async def send_request(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{timeout} seconds."
f"Timed out while waiting for response to {request.__class__.__name__}. "
f"Waited {timeout} seconds."
),
)
)
Expand All @@ -302,9 +296,7 @@ async def send_notification(
notification: SendNotificationT,
related_request_id: RequestId | None = None,
) -> None:
"""Emits a notification, which is a one-way message that does not expect
a response.
"""
"""Emits a notification, which is a one-way message that does not expect a response."""
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
Expand Down Expand Up @@ -373,11 +365,7 @@ async def _receive_loop(self) -> None:
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
)
session_message = SessionMessage(message=error_response)
await self._write_stream.send(session_message)
Expand Down Expand Up @@ -518,13 +506,9 @@ async def send_progress_notification(
total: float | None = None,
message: str | None = None,
) -> None:
"""Sends a progress notification for a request that is currently being
processed.
"""
"""Sends a progress notification for a request that is currently being processed."""

async def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass # pragma: no cover
Loading