Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 155 additions & 96 deletions src/geminimcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,81 @@
from __future__ import annotations

import json
import os
import queue
import subprocess
import threading
import time
import uuid
from pathlib import Path
from typing import Annotated, Any, Dict, Generator, List, Literal, Optional
from typing import Annotated, Any, Dict, Generator, Optional

from mcp.server.fastmcp import FastMCP
from pydantic import BeforeValidator, Field
from pydantic import Field
import shutil

# FastMCP server instance; tools are registered on it with the @mcp.tool decorator.
mcp = FastMCP("Gemini MCP Server-from guda.studio")

# Default timeout for gemini CLI execution, in seconds (300 s = 5 minutes).
# Used as the default `timeout` of run_shell_command and of the `gemini` tool.
DEFAULT_TIMEOUT = 300

def run_shell_command(cmd: list[str], cwd: str | None = None) -> Generator[str, None, None]:

class GeminiAuthError(Exception):
    """Signal that the Gemini CLI has not been authenticated."""


class GeminiTimeoutError(Exception):
    """Signal that a Gemini CLI run exceeded its allowed execution time."""


def check_gemini_auth() -> tuple[bool, str]:
    """Run a cheap pre-flight probe of the Gemini CLI.

    Executes ``gemini --version`` with a 10 second limit. This only proves
    that the executable is installed and starts without hanging (for example
    on an interactive login prompt); it does not validate credentials. A real
    authentication failure only surfaces when a prompt is actually sent.

    The executable is resolved with ``shutil.which`` -- the same lookup
    ``run_shell_command`` performs -- so wrapper scripts such as
    ``gemini.cmd`` on Windows are found. Passing the bare name ``"gemini"``
    to ``subprocess`` would miss them and wrongly report "not found".

    Returns:
        Tuple of ``(is_ok, status_message)``. On success the message is the
        CLI's version output. On failure it is a human-readable reason:
        the CLI's stderr (falling back to stdout, then to the exit code) or a
        description of why the probe could not be run.
    """
    gemini_path = shutil.which("gemini")
    if gemini_path is None:
        return (False, "Gemini CLI not found in PATH")

    try:
        # Try a quick command to check that the CLI starts at all.
        result = subprocess.run(
            [gemini_path, "--version"],
            capture_output=True,
            timeout=10,
            text=True,
        )
    except subprocess.TimeoutExpired:
        return (False, "Gemini CLI timed out during auth check")
    except FileNotFoundError:
        # The file vanished between the lookup and the launch.
        return (False, "Gemini CLI not found in PATH")
    except Exception as e:
        # Deliberate catch-all: this probe must never raise into the tool
        # handler, it only reports a status.
        return (False, f"Auth check failed: {e}")

    if result.returncode == 0:
        return (True, result.stdout.strip())

    # Failures are normally explained on stderr; never return an empty reason.
    reason = (
        result.stderr.strip()
        or result.stdout.strip()
        or f"gemini exited with code {result.returncode}"
    )
    return (False, reason)


def run_shell_command(
cmd: list[str],
cwd: str | None = None,
timeout: int = DEFAULT_TIMEOUT,
) -> Generator[str, None, None]:
"""Execute a command and stream its output line-by-line.

Args:
cmd: Command and arguments as a list (e.g., ["gemini", "-o", "stream-json", "--", "prompt"])
cmd: Command and arguments as a list
cwd: Working directory for the command
timeout: Maximum execution time in seconds

Yields:
Output lines from the command

Raises:
GeminiTimeoutError: If execution exceeds timeout
"""
popen_cmd = cmd
popen_cmd = cmd.copy() # Don't mutate the original list

gemini_path = shutil.which("gemini") or cmd[0]
popen_cmd[0] = gemini_path

# if os.name == "nt" and gemini_path.lower().endswith((".cmd", ".bat")):
# from subprocess import list2cmdline
# popen_cmd = ["cmd.exe", "/s", "/c", list2cmdline(cmd)]

process = subprocess.Popen(
popen_cmd,
shell=False,
Expand All @@ -49,8 +89,9 @@ def run_shell_command(cmd: list[str], cwd: str | None = None) -> Generator[str,
cwd=cwd,
)

output_queue: queue.Queue[str | None] = queue.Queue()
output_queue: queue.Queue[str | None] = queue.Queue(maxsize=10000) # Bounded queue
GRACEFUL_SHUTDOWN_DELAY = 0.3
start_time = time.time()

def is_turn_completed(line: str) -> bool:
"""Check if the line indicates turn completion via JSON parsing."""
Expand All @@ -63,21 +104,40 @@ def is_turn_completed(line: str) -> bool:
def read_output() -> None:
"""Read process output in a separate thread."""
if process.stdout:
for line in iter(process.stdout.readline, ""):
stripped = line.strip()
output_queue.put(stripped)
if is_turn_completed(stripped):
time.sleep(GRACEFUL_SHUTDOWN_DELAY)
process.terminate()
break
process.stdout.close()
try:
for line in iter(process.stdout.readline, ""):
stripped = line.strip()
try:
output_queue.put(stripped, timeout=1)
except queue.Full:
# Queue is full, skip this line to prevent memory issues
continue
if is_turn_completed(stripped):
time.sleep(GRACEFUL_SHUTDOWN_DELAY)
process.terminate()
break
finally:
process.stdout.close()
output_queue.put(None)

thread = threading.Thread(target=read_output)
thread = threading.Thread(target=read_output, daemon=True)
thread.start()

# Yield lines while process is running
# Yield lines while process is running, with timeout check
while True:
# Check for timeout
elapsed = time.time() - start_time
if elapsed > timeout:
process.kill()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
pass
raise GeminiTimeoutError(
f"Gemini CLI execution timed out after {timeout} seconds. "
"This may indicate the CLI is waiting for authentication or is stuck."
)

try:
line = output_queue.get(timeout=0.5)
if line is None:
Expand All @@ -87,13 +147,23 @@ def read_output() -> None:
if process.poll() is not None and not thread.is_alive():
break

# Cleanup
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
pass # Process is truly stuck, nothing we can do

# Wait for thread with timeout
thread.join(timeout=5)
if thread.is_alive():
# Thread is stuck, but it's a daemon thread so it will be cleaned up on exit
pass

# Drain remaining queue items
while not output_queue.empty():
try:
line = output_queue.get_nowait()
Expand All @@ -103,35 +173,11 @@ def read_output() -> None:
break


def windows_escape(prompt):
    """
    Windows-style string escaping helper.

    Rewrites common special characters into their backslash-escaped form,
    intended for command-line, JSON or path use. For example, a newline
    character becomes the two characters ``\\n`` and a double quote becomes
    ``\\"``.
    """
    # Every source character is mapped in a single pass, so the backslashes
    # introduced by one escape are never escaped a second time.
    escape_table = str.maketrans(
        {
            "\\": "\\\\",  # backslash itself
            '"': '\\"',  # double quote, so string boundaries stay intact
            "\n": "\\n",  # line feed (Windows uses \r\n; each half is escaped separately)
            "\r": "\\r",  # carriage return
            "\t": "\\t",  # tab
            "\b": "\\b",  # backspace
            "\f": "\\f",  # form feed
            "'": "\\'",  # single quote, escaped as well to be on the safe side
        }
    )
    return prompt.translate(escape_table)


@mcp.tool(
name="gemini",
description="""
Invokes the Gemini CLI to execute AI-driven tasks, returning structured JSON events and a session identifier for conversation continuity.
Invokes the Gemini CLI to execute AI-driven tasks, returning structured JSON events and a session identifier for conversation continuity.

**Return structure:**
- `success`: boolean indicating execution status
- `SESSION_ID`: unique identifier for resuming this conversation in future calls
Expand All @@ -145,7 +191,7 @@ def windows_escape(prompt):
- Use `return_all_messages` only when detailed execution traces are necessary (increases payload size)
- Only pass `model` when the user has explicitly requested a specific model
""",
meta={"version": "0.0.0", "author": "guda.studio"},
meta={"version": "0.1.1", "author": "guda.studio"},
)
async def gemini(
PROMPT: Annotated[str, "Instruction for the task to send to gemini."],
Expand All @@ -166,18 +212,28 @@ async def gemini(
str,
"The model to use for the gemini session. This parameter is strictly prohibited unless explicitly specified by the user.",
] = "",
timeout: Annotated[
int,
Field(description="Maximum execution time in seconds. Defaults to 300 (5 minutes)."),
] = DEFAULT_TIMEOUT,
) -> Dict[str, Any]:
"""Execute a gemini CLI session and return the results."""


# Check if workspace exists
if not cd.exists():
success = False
err_message = f"The workspace root directory `{cd.absolute().as_posix()}` does not exist. Please check the path and try again."
return {"success": success, "error": err_message}
return {
"success": False,
"error": f"The workspace root directory `{cd.absolute().as_posix()}` does not exist."
}

if os.name == "nt":
PROMPT = windows_escape(PROMPT)
else:
PROMPT = PROMPT
# Pre-flight auth check
auth_ok, auth_msg = check_gemini_auth()
if not auth_ok:
return {
"success": False,
"error": f"Gemini CLI authentication check failed: {auth_msg}. "
"Please run 'gemini auth login' manually in your terminal."
}

cmd = ["gemini", "--prompt", PROMPT, "-o", "stream-json"]

Expand All @@ -190,66 +246,69 @@ async def gemini(
if SESSION_ID:
cmd.extend(["--resume", SESSION_ID])

all_messages = []
all_messages: list[dict] = []
agent_messages = ""
success = True
err_message = ""
thread_id: Optional[str] = None

for line in run_shell_command(cmd, cwd=cd.absolute().as_posix()):
try:
line_dict = json.loads(line.strip())
all_messages.append(line_dict)
item_type = line_dict.get("type", "")
item_role = line_dict.get("role", "")
if item_type == "message" and item_role == "assistant":
if (
"The --prompt (-p) flag has been deprecated and will be removed in a future version. Please use a positional argument for your prompt. See gemini --help for more information.\n"
in line_dict.get("content", "")
):
continue
agent_messages = agent_messages + line_dict.get("content", "")
if line_dict.get("session_id") is not None:
thread_id = line_dict.get("session_id")
# if "fail" in line_dict.get("type", ""):
# success = False
# err_message = "gemini error: " + line_dict.get("error", {}).get("message", "")
# break
# if "error" in line_dict.get("type", ""):
# success = False
# err_message = "gemini error: " + line_dict.get("message", "")
except json.JSONDecodeError as error:
# Improved error handling: include problematic line
err_message += "\n\n[json decode error] " + line
continue
except Exception as error:
err_message += "\n\n[unexpected error] " + f"Unexpected error: {error}. Line: {line!r}"
break
try:
for line in run_shell_command(cmd, cwd=cd.absolute().as_posix(), timeout=timeout):
try:
line_dict = json.loads(line.strip())
all_messages.append(line_dict)
item_type = line_dict.get("type", "")
item_role = line_dict.get("role", "")

if item_type == "message" and item_role == "assistant":
content = line_dict.get("content", "")
# Skip deprecation warnings
if "The --prompt (-p) flag has been deprecated" in content:
continue
agent_messages += content

if line_dict.get("session_id") is not None:
thread_id = line_dict.get("session_id")

except json.JSONDecodeError:
# Limit error message accumulation
if len(err_message) < 2000:
err_message += f"\n[json decode error] {line[:200]}"
continue
except Exception as error:
err_message += f"\n[unexpected error] {error}"
break

except GeminiTimeoutError as e:
return {
"success": False,
"error": str(e),
"all_messages": all_messages if return_all_messages else None,
}

# Validate results
if thread_id is None:
success = False
err_message = (
"Failed to get `SESSION_ID` from the gemini session. \n\n" + err_message
)
err_message = "Failed to get `SESSION_ID` from the gemini session.\n" + err_message

if success and len(agent_messages) == 0:
success = False
err_message = (
"Failed to retrieve `agent_messages` data from the Gemini session. This might be due to Gemini performing a tool call. You can continue using the `SESSION_ID` to proceed with the conversation. \n\n "
+ err_message
"Failed to retrieve `agent_messages` from the Gemini session. "
"This might be due to Gemini performing a tool call. "
"You can continue using the `SESSION_ID` to proceed.\n" + err_message
)

# Build result
if success:
result: Dict[str, Any] = {
"success": True,
"SESSION_ID": thread_id,
"agent_messages": agent_messages,
# "PROMPT": PROMPT,
}
else:
result = {"success": False, "error": err_message}
result = {"success": False, "error": err_message.strip()}

if return_all_messages:
result["all_messages"] = all_messages

Expand Down