diff --git a/src/proxmox_mcp/server.py b/src/proxmox_mcp/server.py index e2207f8..a2e04bd 100644 --- a/src/proxmox_mcp/server.py +++ b/src/proxmox_mcp/server.py @@ -36,13 +36,14 @@ # Import tool modules to register them with mcp from proxmox_mcp.prompts import prompts # noqa: E402, F401 from proxmox_mcp.resources import resources # noqa: E402, F401 -from proxmox_mcp.tools import ( # noqa: E402, F401 # noqa: E402, F401 # noqa: E402, F401 +from proxmox_mcp.tools import ( # noqa: E402, F401 backup, cluster, container, disk, network, node, + ssh_tools, storage, task, vm, diff --git a/src/proxmox_mcp/ssh.py b/src/proxmox_mcp/ssh.py index 224f82c..d78f21a 100644 --- a/src/proxmox_mcp/ssh.py +++ b/src/proxmox_mcp/ssh.py @@ -1,4 +1,4 @@ -"""SSH execution layer for running commands on Proxmox nodes.""" +"""SSH execution layer for running commands on Proxmox nodes and VMs.""" import asyncio import logging @@ -30,13 +30,29 @@ def success(self) -> bool: class SSHExecutor: - """Execute commands on Proxmox nodes over SSH.""" + """Execute commands on Proxmox nodes and VMs over SSH.""" def __init__(self, config: ProxmoxConfig) -> None: self.config = config - def _create_client(self, host: str) -> paramiko.SSHClient: - """Create and configure a paramiko SSH client.""" + def _create_client( + self, + host: str, + *, + port: int | None = None, + username: str | None = None, + password: str | None = None, + key_path: str | None = None, + ) -> paramiko.SSHClient: + """Create and configure a paramiko SSH client. + + Args: + host: Target hostname or IP. + port: SSH port override (defaults to config). + username: SSH username override (defaults to config). + password: SSH password override (defaults to config). + key_path: SSH private key path override (defaults to config). 
+ """ client = paramiko.SSHClient() if self.config.PROXMOX_SSH_HOST_KEY_CHECKING: @@ -51,23 +67,25 @@ def _create_client(self, host: str) -> paramiko.SSHClient: logger.warning("SSH host key checking disabled — vulnerable to MITM attacks") client.set_missing_host_key_policy(paramiko.WarningPolicy()) + effective_port = port or self.config.PROXMOX_SSH_PORT + effective_user = username or self.config.PROXMOX_SSH_USER + effective_key = key_path or self.config.PROXMOX_SSH_KEY_PATH + effective_password = password or self.config.PROXMOX_SSH_PASSWORD or self.config.PROXMOX_PASSWORD + connect_kwargs: dict = { "hostname": host, - "port": self.config.PROXMOX_SSH_PORT, - "username": self.config.PROXMOX_SSH_USER, + "port": effective_port, + "username": effective_user, "timeout": 10, } - # Resolve SSH password: dedicated SSH password, or fall back to Proxmox API password - ssh_password = self.config.PROXMOX_SSH_PASSWORD or self.config.PROXMOX_PASSWORD - - if self.config.PROXMOX_SSH_KEY_PATH: - key_path = Path(self.config.PROXMOX_SSH_KEY_PATH).expanduser() - if not key_path.exists(): - raise SSHExecutionError(f"SSH key not found: {key_path}") - connect_kwargs["key_filename"] = str(key_path) - elif ssh_password: - connect_kwargs["password"] = ssh_password + if effective_key: + resolved_key = Path(effective_key).expanduser() + if not resolved_key.exists(): + raise SSHExecutionError(f"SSH key not found: {resolved_key}") + connect_kwargs["key_filename"] = str(resolved_key) + elif effective_password: + connect_kwargs["password"] = effective_password # Skip default key discovery when password is explicitly provided connect_kwargs["look_for_keys"] = False connect_kwargs["allow_agent"] = False @@ -81,10 +99,22 @@ def _create_client(self, host: str) -> paramiko.SSHClient: return client - def _execute_sync(self, host: str, command: str, timeout: int = 30) -> SSHResult: + def _execute_sync( + self, + host: str, + command: str, + timeout: int = 30, + *, + port: int | None = None, + username: 
str | None = None, + password: str | None = None, + key_path: str | None = None, + ) -> SSHResult: """Execute a command over SSH synchronously.""" timeout = min(timeout, MAX_SSH_TIMEOUT) - client = self._create_client(host) + client = self._create_client( + host, port=port, username=username, password=password, key_path=key_path + ) try: logger.debug("SSH %s: %s", host, command) _, stdout, stderr = client.exec_command(command, timeout=timeout) @@ -134,3 +164,54 @@ async def execute(self, node: str, command: str, timeout: int = 30) -> SSHResult result.stderr[:200], ) return result + + async def execute_on_host( + self, + host: str, + command: str, + timeout: int = 30, + *, + port: int | None = None, + username: str | None = None, + password: str | None = None, + key_path: str | None = None, + ) -> SSHResult: + """Execute a command on an arbitrary host (VM, container, or node) asynchronously. + + Unlike execute(), this connects directly to the given host/IP without + resolving through Proxmox node configuration. + + Args: + host: Target hostname or IP address. + command: Shell command to execute. + timeout: Command timeout in seconds (max 120). + port: SSH port override. + username: SSH username override. + password: SSH password override. + key_path: SSH private key path override. 
+ """ + full_command = ( + "export PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:$PATH; " + + command + ) + + logger.info("SSH executing on %s: %s", host, command) + result = await asyncio.to_thread( + self._execute_sync, + host, + full_command, + timeout, + port=port, + username=username, + password=password, + key_path=key_path, + ) + + if result.exit_code != 0: + logger.warning( + "SSH command exited %d on %s: stderr=%s", + result.exit_code, + host, + result.stderr[:200], + ) + return result diff --git a/src/proxmox_mcp/tools/ssh_tools.py b/src/proxmox_mcp/tools/ssh_tools.py new file mode 100644 index 0000000..68669ed --- /dev/null +++ b/src/proxmox_mcp/tools/ssh_tools.py @@ -0,0 +1,760 @@ +"""SSH tools for direct command execution on Proxmox VMs and containers. + +Provides curated, high-level tools for package installation, service management, +file transfer, script execution, and system information retrieval. Supports +auto-discovery of VM/CT IP addresses via QEMU guest agent or LXC interfaces. 
+""" + +import base64 +import logging + +from proxmox_mcp.utils.errors import ( + InvalidParameterError, + SSHExecutionError, + format_error_response, +) +from proxmox_mcp.utils.sanitizers import ( + check_shell_injection, + validate_package_name, + validate_remote_file_path, + validate_script_interpreter, + validate_service_action, + validate_service_name, +) +from proxmox_mcp.utils.validators import validate_vmid + +logger = logging.getLogger("proxmox-mcp") + + +def get_client(): + from proxmox_mcp.server import proxmox_client + + return proxmox_client + + +def get_mcp(): + from proxmox_mcp.server import mcp + + return mcp + + +def get_ssh(): + from proxmox_mcp.server import ssh_executor + + return ssh_executor + + +mcp = get_mcp() + + +# --------------------------------------------------------------------------- +# IP auto-discovery helpers +# --------------------------------------------------------------------------- + + +async def _resolve_vm_ip(client, vmid: int, node: str) -> str: + """Auto-discover a VM's IP address via QEMU guest agent or LXC interfaces. + + Tries QEMU guest agent first, then LXC interfaces API. Filters out + loopback and link-local addresses, preferring IPv4. + + Returns: + The best available IP address for SSH connection. + + Raises: + SSHExecutionError: If no usable IP address is found. + """ + interfaces = [] + + # Try QEMU guest agent + try: + data = await client.api_call( + client.api.nodes(node).qemu(vmid).agent("network-get-interfaces").get + ) + raw = data.get("result", data) + if isinstance(raw, list): + interfaces = raw + except Exception: + pass + + # Try LXC interfaces + if not interfaces: + try: + data = await client.api_call( + client.api.nodes(node).lxc(vmid).interfaces.get + ) + if isinstance(data, list): + interfaces = data + except Exception: + pass + + if not interfaces: + raise SSHExecutionError( + f"Cannot discover IP for VMID {vmid}. 
" + f"Ensure the VM/CT is running and has QEMU guest agent (VMs) or network configured (CTs)." + ) + + # Extract usable IPs: prefer IPv4, skip loopback and link-local + ipv4_candidates = [] + ipv6_candidates = [] + + for iface in interfaces: + iface_name = iface.get("name", "") + if iface_name == "lo": + continue + + # QEMU agent format: ip-addresses list + for addr_info in iface.get("ip-addresses", []): + ip = addr_info.get("ip-address", "") + ip_type = addr_info.get("ip-address-type", "") + if ip_type == "ipv4" and not ip.startswith("127.") and not ip.startswith("169.254."): + ipv4_candidates.append(ip) + elif ip_type == "ipv6" and not ip.startswith("fe80:") and ip != "::1": + ipv6_candidates.append(ip) + + # LXC format: inet/inet6 fields (CIDR notation) + inet = iface.get("inet", "") + if inet: + ip = inet.split("/")[0] + if not ip.startswith("127.") and not ip.startswith("169.254."): + ipv4_candidates.append(ip) + inet6 = iface.get("inet6", "") + if inet6: + ip = inet6.split("/")[0] + if not ip.startswith("fe80:") and ip != "::1": + ipv6_candidates.append(ip) + + # LXC format: hwaddr with ip (some proxmox versions) + ip_field = iface.get("ip", "") + if ip_field: + ip = ip_field.split("/")[0] + if not ip.startswith("127.") and not ip.startswith("169.254."): + ipv4_candidates.append(ip) + + best_ip = next(iter(ipv4_candidates), None) or next(iter(ipv6_candidates), None) + if not best_ip: + raise SSHExecutionError( + f"No usable IP address found for VMID {vmid}. " + f"Interfaces found but all addresses are loopback or link-local." + ) + + logger.info("Resolved VMID %d to IP %s", vmid, best_ip) + return best_ip + + +async def _get_ssh_target( + vmid: int, + node: str | None = None, + target_ip: str | None = None, +) -> tuple[str, str]: + """Resolve the SSH target IP and node for a VMID. + + Args: + vmid: The VM/CT ID. + node: Node name override. + target_ip: Direct IP override (skips auto-discovery). + + Returns: + Tuple of (ip_address, node_name). 
+ """ + client = get_client() + validate_vmid(vmid) + resolved_node = await client.resolve_node(vmid, node) + + if target_ip: + check_shell_injection(target_ip, "target_ip") + return target_ip, resolved_node + + ip = await _resolve_vm_ip(client, vmid, resolved_node) + return ip, resolved_node + + +def _build_ssh_overrides( + ssh_user: str | None, + ssh_password: str | None, + ssh_key_path: str | None, + ssh_port: int | None, +) -> dict: + """Build SSH credential override kwargs.""" + overrides: dict = {} + if ssh_user: + check_shell_injection(ssh_user, "ssh_user") + overrides["username"] = ssh_user + if ssh_password: + overrides["password"] = ssh_password + if ssh_key_path: + check_shell_injection(ssh_key_path, "ssh_key_path") + overrides["key_path"] = ssh_key_path + if ssh_port is not None: + overrides["port"] = ssh_port + return overrides + + +# --------------------------------------------------------------------------- +# Package manager detection +# --------------------------------------------------------------------------- + + +async def _detect_package_manager(ssh, host: str, **ssh_kwargs) -> str: + """Detect the package manager on a remote host. + + Returns one of: apt, dnf, yum, apk, zypper, pacman. 
+ """ + managers = [ + ("apt-get", "apt"), + ("dnf", "dnf"), + ("yum", "yum"), + ("apk", "apk"), + ("zypper", "zypper"), + ("pacman", "pacman"), + ] + for binary, name in managers: + result = await ssh.execute_on_host( + host, f"command -v {binary} >/dev/null 2>&1 && echo found", timeout=10, **ssh_kwargs + ) + if result.success and "found" in result.stdout: + return name + + raise SSHExecutionError("No supported package manager found on target system.") + + +def _build_install_command(manager: str, packages: list[str]) -> str: + """Build the install command for a given package manager.""" + pkg_str = " ".join(packages) + commands = { + "apt": f"DEBIAN_FRONTEND=noninteractive apt-get update -qq && apt-get install -y -qq {pkg_str}", + "dnf": f"dnf install -y {pkg_str}", + "yum": f"yum install -y {pkg_str}", + "apk": f"apk add --no-cache {pkg_str}", + "zypper": f"zypper install -y {pkg_str}", + "pacman": f"pacman -Sy --noconfirm {pkg_str}", + } + return commands[manager] + + +def _build_remove_command(manager: str, packages: list[str]) -> str: + """Build the remove command for a given package manager.""" + pkg_str = " ".join(packages) + commands = { + "apt": f"apt-get remove -y -qq {pkg_str}", + "dnf": f"dnf remove -y {pkg_str}", + "yum": f"yum remove -y {pkg_str}", + "apk": f"apk del {pkg_str}", + "zypper": f"zypper remove -y {pkg_str}", + "pacman": f"pacman -R --noconfirm {pkg_str}", + } + return commands[manager] + + +# --------------------------------------------------------------------------- +# Tool 1: install_package +# --------------------------------------------------------------------------- + + +@mcp.tool() +async def install_package( + vmid: int, + packages: list[str], + action: str = "install", + node: str | None = None, + target_ip: str | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_path: str | None = None, + ssh_port: int | None = None, +) -> dict: + """Install or remove packages on a VM or container via SSH. 
+ + Auto-detects the package manager (apt, dnf, yum, apk, zypper, pacman). + + Args: + vmid: The VM/CT ID to connect to. + packages: List of package names to install or remove. + action: 'install' or 'remove'. + node: Node name override. Auto-detected if omitted. + target_ip: Direct IP/hostname of the VM. Skips auto-discovery if provided. + ssh_user: SSH username override (defaults to global config). + ssh_password: SSH password override (defaults to global config). + ssh_key_path: SSH private key path override (defaults to global config). + ssh_port: SSH port override (defaults to global config). + """ + try: + if not packages: + raise InvalidParameterError("At least one package name is required.") + if action not in ("install", "remove"): + raise InvalidParameterError("Action must be 'install' or 'remove'.") + + for pkg in packages: + validate_package_name(pkg) + + ip, resolved_node = await _get_ssh_target(vmid, node, target_ip) + ssh_kwargs = _build_ssh_overrides(ssh_user, ssh_password, ssh_key_path, ssh_port) + ssh = get_ssh() + + # Detect package manager + manager = await _detect_package_manager(ssh, ip, **ssh_kwargs) + logger.info( + "Detected package manager '%s' on VMID %d (%s)", manager, vmid, ip + ) + + # Build and execute command + if action == "install": + command = _build_install_command(manager, packages) + else: + command = _build_remove_command(manager, packages) + + result = await ssh.execute_on_host(ip, command, timeout=120, **ssh_kwargs) + + if not result.success: + return { + "status": "error", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "action": action, + "packages": packages, + "package_manager": manager, + "exit_code": result.exit_code, + "error": result.stderr[-500:] if result.stderr else "Unknown error", + } + + return { + "status": "success", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "action": action, + "packages": packages, + "package_manager": manager, + "output": result.stdout[-1000:] if 
result.stdout else "", + } + except Exception as e: + logger.error("Failed to %s packages on VMID %d: %s", action, vmid, e) + return format_error_response(e) + + +# --------------------------------------------------------------------------- +# Tool 2: manage_service +# --------------------------------------------------------------------------- + + +@mcp.tool() +async def manage_service( + vmid: int, + service: str, + action: str, + node: str | None = None, + target_ip: str | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_path: str | None = None, + ssh_port: int | None = None, +) -> dict: + """Manage a systemd service on a VM or container via SSH. + + Args: + vmid: The VM/CT ID to connect to. + service: Systemd service name (e.g., 'nginx', 'docker'). + action: Service action: 'start', 'stop', 'restart', 'reload', 'enable', 'disable', 'status'. + node: Node name override. Auto-detected if omitted. + target_ip: Direct IP/hostname of the VM. Skips auto-discovery if provided. + ssh_user: SSH username override (defaults to global config). + ssh_password: SSH password override (defaults to global config). + ssh_key_path: SSH private key path override (defaults to global config). + ssh_port: SSH port override (defaults to global config). 
+ """ + try: + validate_service_name(service) + validate_service_action(action) + + ip, resolved_node = await _get_ssh_target(vmid, node, target_ip) + ssh_kwargs = _build_ssh_overrides(ssh_user, ssh_password, ssh_key_path, ssh_port) + ssh = get_ssh() + + command = f"systemctl {action} {service}" + # For status, also get detailed info + if action == "status": + command = ( + f"systemctl is-active {service} 2>/dev/null; " + f"systemctl is-enabled {service} 2>/dev/null; " + f"systemctl status {service} --no-pager -l 2>/dev/null" + ) + + result = await ssh.execute_on_host(ip, command, timeout=30, **ssh_kwargs) + + if action == "status": + # Parse status output + lines = result.stdout.strip().splitlines() + is_active = lines[0].strip() if len(lines) > 0 else "unknown" + is_enabled = lines[1].strip() if len(lines) > 1 else "unknown" + full_status = "\n".join(lines[2:]) if len(lines) > 2 else "" + return { + "status": "success", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "service": service, + "is_active": is_active, + "is_enabled": is_enabled, + "details": full_status[-1000:], + } + + if not result.success: + return { + "status": "error", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "service": service, + "action": action, + "exit_code": result.exit_code, + "error": result.stderr[-500:] if result.stderr else "Unknown error", + } + + return { + "status": "success", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "service": service, + "action": action, + "output": result.stdout[-500:] if result.stdout else "", + } + except Exception as e: + logger.error("Failed to %s service '%s' on VMID %d: %s", action, service, vmid, e) + return format_error_response(e) + + +# --------------------------------------------------------------------------- +# Tool 3: transfer_file +# --------------------------------------------------------------------------- + + +@mcp.tool() +async def transfer_file( + vmid: int, + content: str, + destination: 
str, + permissions: str = "0644", + owner: str | None = None, + node: str | None = None, + target_ip: str | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_path: str | None = None, + ssh_port: int | None = None, +) -> dict: + """Upload file content to a VM or container via SSH. + + Writes the provided text content to a file on the target system. + Uses base64 encoding for safe transfer of arbitrary content. + + Args: + vmid: The VM/CT ID to connect to. + content: The file content to write. + destination: Absolute path on the target (e.g., '/etc/nginx/nginx.conf'). + permissions: File permissions in octal (e.g., '0644', '0755'). + owner: File owner in 'user:group' format (e.g., 'www-data:www-data'). + node: Node name override. Auto-detected if omitted. + target_ip: Direct IP/hostname of the VM. Skips auto-discovery if provided. + ssh_user: SSH username override (defaults to global config). + ssh_password: SSH password override (defaults to global config). + ssh_key_path: SSH private key path override (defaults to global config). + ssh_port: SSH port override (defaults to global config). + """ + try: + validate_remote_file_path(destination) + + # Validate permissions format + if not _is_valid_permissions(permissions): + raise InvalidParameterError( + f"Permissions '{permissions}' are invalid. Use octal format like '0644' or '0755'." 
+ ) + + if owner: + check_shell_injection(owner, "owner") + + ip, resolved_node = await _get_ssh_target(vmid, node, target_ip) + ssh_kwargs = _build_ssh_overrides(ssh_user, ssh_password, ssh_key_path, ssh_port) + ssh = get_ssh() + + # Base64 encode content for safe transfer + encoded = base64.b64encode(content.encode("utf-8")).decode("ascii") + + # Ensure parent directory exists, write file, set permissions + parent_dir = "/".join(destination.rsplit("/", 1)[:-1]) or "/" + commands = [ + f"mkdir -p {parent_dir}", + f"echo '{encoded}' | base64 -d > {destination}", + f"chmod {permissions} {destination}", + ] + if owner: + commands.append(f"chown {owner} {destination}") + + command = " && ".join(commands) + result = await ssh.execute_on_host(ip, command, timeout=30, **ssh_kwargs) + + if not result.success: + return { + "status": "error", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "destination": destination, + "exit_code": result.exit_code, + "error": result.stderr[-500:] if result.stderr else "Unknown error", + } + + return { + "status": "success", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "destination": destination, + "permissions": permissions, + "owner": owner, + "size_bytes": len(content.encode("utf-8")), + } + except Exception as e: + logger.error("Failed to transfer file to VMID %d: %s", vmid, e) + return format_error_response(e) + + +def _is_valid_permissions(perms: str) -> bool: + """Check if a string is a valid octal permission (e.g., '0644', '755').""" + if len(perms) not in (3, 4): + return False + return all(c in "01234567" for c in perms) + + +# --------------------------------------------------------------------------- +# Tool 4: execute_script +# --------------------------------------------------------------------------- + + +@mcp.tool() +async def execute_script( + vmid: int, + script: str, + interpreter: str = "bash", + timeout: int = 60, + node: str | None = None, + target_ip: str | None = None, + ssh_user: str | 
None = None, + ssh_password: str | None = None, + ssh_key_path: str | None = None, + ssh_port: int | None = None, +) -> dict: + """Upload and execute a script on a VM or container via SSH. + + The script content is base64-encoded, transferred, executed with the + specified interpreter, and then cleaned up. + + Args: + vmid: The VM/CT ID to connect to. + script: The script content to execute. + interpreter: Script interpreter: 'bash', 'sh', 'python3', 'python', 'perl'. + timeout: Execution timeout in seconds (max 120). + node: Node name override. Auto-detected if omitted. + target_ip: Direct IP/hostname of the VM. Skips auto-discovery if provided. + ssh_user: SSH username override (defaults to global config). + ssh_password: SSH password override (defaults to global config). + ssh_key_path: SSH private key path override (defaults to global config). + ssh_port: SSH port override (defaults to global config). + """ + try: + if not script.strip(): + raise InvalidParameterError("Script content cannot be empty.") + + validate_script_interpreter(interpreter) + timeout = min(max(timeout, 5), 120) + + ip, resolved_node = await _get_ssh_target(vmid, node, target_ip) + ssh_kwargs = _build_ssh_overrides(ssh_user, ssh_password, ssh_key_path, ssh_port) + ssh = get_ssh() + + # Base64 encode and transfer via pipe to interpreter + encoded = base64.b64encode(script.encode("utf-8")).decode("ascii") + command = f"echo '{encoded}' | base64 -d | {interpreter}" + + result = await ssh.execute_on_host(ip, command, timeout=timeout, **ssh_kwargs) + + return { + "status": "success" if result.success else "error", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "interpreter": interpreter, + "exit_code": result.exit_code, + "stdout": result.stdout[-2000:] if result.stdout else "", + "stderr": result.stderr[-1000:] if result.stderr else "", + } + except Exception as e: + logger.error("Failed to execute script on VMID %d: %s", vmid, e) + return format_error_response(e) + + +# 
--------------------------------------------------------------------------- +# Tool 5: get_system_info +# --------------------------------------------------------------------------- + + +@mcp.tool() +async def get_system_info( + vmid: int, + node: str | None = None, + target_ip: str | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_path: str | None = None, + ssh_port: int | None = None, +) -> dict: + """Get system information from a VM or container via SSH. + + Retrieves hostname, OS, kernel, uptime, CPU, memory, and disk usage. + + Args: + vmid: The VM/CT ID to connect to. + node: Node name override. Auto-detected if omitted. + target_ip: Direct IP/hostname of the VM. Skips auto-discovery if provided. + ssh_user: SSH username override (defaults to global config). + ssh_password: SSH password override (defaults to global config). + ssh_key_path: SSH private key path override (defaults to global config). + ssh_port: SSH port override (defaults to global config). 
+ """ + try: + ip, resolved_node = await _get_ssh_target(vmid, node, target_ip) + ssh_kwargs = _build_ssh_overrides(ssh_user, ssh_password, ssh_key_path, ssh_port) + ssh = get_ssh() + + # Gather system info in a single SSH call for efficiency + info_script = ( + "echo '---HOSTNAME---'; hostname; " + "echo '---OS---'; cat /etc/os-release 2>/dev/null || echo 'unknown'; " + "echo '---KERNEL---'; uname -r; " + "echo '---ARCH---'; uname -m; " + "echo '---UPTIME---'; uptime -p 2>/dev/null || uptime; " + "echo '---CPU---'; nproc 2>/dev/null || grep -c processor /proc/cpuinfo; " + "echo '---MEMORY---'; free -b 2>/dev/null | grep Mem; " + "echo '---DISK---'; df -B1 / 2>/dev/null | tail -1; " + "echo '---LOAD---'; cat /proc/loadavg 2>/dev/null; " + "echo '---END---'" + ) + + result = await ssh.execute_on_host(ip, info_script, timeout=15, **ssh_kwargs) + + if not result.success: + return { + "status": "error", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + "exit_code": result.exit_code, + "error": result.stderr[-500:] if result.stderr else "Unknown error", + } + + info = _parse_system_info(result.stdout) + return { + "status": "success", + "vmid": vmid, + "node": resolved_node, + "target_ip": ip, + **info, + } + except Exception as e: + logger.error("Failed to get system info from VMID %d: %s", vmid, e) + return format_error_response(e) + + +def _parse_system_info(output: str) -> dict: + """Parse the combined system info output into structured data.""" + sections: dict[str, list[str]] = {} + current_section = "" + + for line in output.splitlines(): + stripped = line.strip() + if stripped.startswith("---") and stripped.endswith("---"): + current_section = stripped.strip("-") + sections[current_section] = [] + elif current_section and current_section != "END": + sections.setdefault(current_section, []).append(stripped) + + info: dict = {} + + # Hostname + info["hostname"] = sections.get("HOSTNAME", ["unknown"])[0] + + # OS + os_lines = sections.get("OS", []) + 
os_info = {} + for line in os_lines: + if "=" in line: + key, _, val = line.partition("=") + os_info[key] = val.strip('"') + info["os"] = os_info.get("PRETTY_NAME", os_info.get("NAME", "unknown")) + info["os_id"] = os_info.get("ID", "unknown") + + # Kernel and arch + info["kernel"] = sections.get("KERNEL", ["unknown"])[0] + info["architecture"] = sections.get("ARCH", ["unknown"])[0] + + # Uptime + info["uptime"] = sections.get("UPTIME", ["unknown"])[0] + + # CPU count + cpu_lines = sections.get("CPU", ["0"]) + try: + info["cpu_count"] = int(cpu_lines[0]) + except ValueError: + info["cpu_count"] = 0 + + # Memory + mem_lines = sections.get("MEMORY", []) + if mem_lines: + parts = mem_lines[0].split() + # free -b output: Mem: total used free shared buff/cache available + if len(parts) >= 4: + try: + info["memory"] = { + "total_bytes": int(parts[1]), + "used_bytes": int(parts[2]), + "free_bytes": int(parts[3]), + "available_bytes": int(parts[6]) if len(parts) > 6 else None, + } + except (ValueError, IndexError): + info["memory"] = {"raw": mem_lines[0]} + else: + info["memory"] = {} + + # Disk + disk_lines = sections.get("DISK", []) + if disk_lines: + parts = disk_lines[0].split() + if len(parts) >= 5: + try: + info["root_disk"] = { + "filesystem": parts[0], + "total_bytes": int(parts[1]), + "used_bytes": int(parts[2]), + "available_bytes": int(parts[3]), + "use_percent": parts[4], + } + except (ValueError, IndexError): + info["root_disk"] = {"raw": disk_lines[0]} + else: + info["root_disk"] = {} + + # Load average + load_lines = sections.get("LOAD", []) + if load_lines: + parts = load_lines[0].split() + if len(parts) >= 3: + info["load_average"] = { + "1min": parts[0], + "5min": parts[1], + "15min": parts[2], + } + else: + info["load_average"] = {} + + return info diff --git a/src/proxmox_mcp/utils/sanitizers.py b/src/proxmox_mcp/utils/sanitizers.py index e3ef7f3..64e1074 100644 --- a/src/proxmox_mcp/utils/sanitizers.py +++ b/src/proxmox_mcp/utils/sanitizers.py @@ -91,6 
+91,14 @@ VALID_FILESYSTEMS = frozenset({"ext4", "xfs", "vfat"}) VALID_PARTITION_TABLES = frozenset({"gpt", "msdos"}) +# SSH tool validation patterns +PACKAGE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.+\-]{0,127}$") +SERVICE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-@]{0,127}$") +REMOTE_FILE_PATH_RE = re.compile(r"^/[a-zA-Z0-9_./\-]+$") + +VALID_SERVICE_ACTIONS = frozenset({"start", "stop", "restart", "reload", "enable", "disable", "status"}) +VALID_SCRIPT_INTERPRETERS = frozenset({"bash", "sh", "python3", "python", "perl"}) + def check_shell_injection(value: str, param_name: str) -> None: """Reject any value containing shell metacharacters.""" @@ -235,3 +243,54 @@ def validate_uuid(uuid_str: str) -> str: f"Expected format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" ) return uuid_str + + +def validate_package_name(name: str) -> None: + """Validate a system package name.""" + check_shell_injection(name, "package") + if not PACKAGE_NAME_RE.match(name): + raise InvalidParameterError( + f"Package name '{name}' is invalid. " + f"Must start with alphanumeric and contain only [a-zA-Z0-9_.+-]." + ) + + +def validate_service_name(name: str) -> None: + """Validate a systemd service name.""" + check_shell_injection(name, "service") + if not SERVICE_NAME_RE.match(name): + raise InvalidParameterError( + f"Service name '{name}' is invalid. " + f"Must start with alphanumeric and contain only [a-zA-Z0-9_.\\-@]." + ) + + +def validate_service_action(action: str) -> None: + """Validate a systemd service action.""" + if action not in VALID_SERVICE_ACTIONS: + raise InvalidParameterError( + f"Service action '{action}' is invalid. " + f"Allowed: {', '.join(sorted(VALID_SERVICE_ACTIONS))}" + ) + + +def validate_remote_file_path(path: str) -> None: + """Validate a remote file path for file transfer.""" + check_shell_injection(path, "file_path") + if ".." in path: + raise InvalidParameterError( + f"File path '{path}' contains path traversal (..) which is not allowed." 
+ ) + if not REMOTE_FILE_PATH_RE.match(path): + raise InvalidParameterError( + f"File path '{path}' is invalid. Must be an absolute path with safe characters." + ) + + +def validate_script_interpreter(interpreter: str) -> None: + """Validate a script interpreter name.""" + if interpreter not in VALID_SCRIPT_INTERPRETERS: + raise InvalidParameterError( + f"Interpreter '{interpreter}' is not allowed. " + f"Allowed: {', '.join(sorted(VALID_SCRIPT_INTERPRETERS))}" + ) diff --git a/tests/test_ssh_tools.py b/tests/test_ssh_tools.py new file mode 100644 index 0000000..8325f6c --- /dev/null +++ b/tests/test_ssh_tools.py @@ -0,0 +1,660 @@ +"""Tests for SSH tools (install_package, manage_service, transfer_file, execute_script, get_system_info).""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from proxmox_mcp.ssh import SSHResult + + +@pytest.fixture +def mock_client(): + """Create a mock ProxmoxClient.""" + client = MagicMock() + client.api_call = AsyncMock() + client.validate_node = MagicMock() + client.resolve_node = AsyncMock(return_value="pve1") + client.is_dry_run = False + return client + + +@pytest.fixture +def mock_ssh(): + """Create a mock SSHExecutor.""" + ssh = MagicMock() + ssh.execute = AsyncMock() + ssh.execute_on_host = AsyncMock() + return ssh + + +def _patch_tools(mock_client, mock_ssh, mock_interfaces=None): + """Return a context manager that patches client, ssh, and optionally IP discovery.""" + from contextlib import contextmanager + + @contextmanager + def _ctx(): + # Default: QEMU agent returns an eth0 with 10.0.0.100 + if mock_interfaces is None: + ifaces = { + "result": [ + { + "name": "lo", + "ip-addresses": [ + {"ip-address": "127.0.0.1", "ip-address-type": "ipv4"} + ], + }, + { + "name": "eth0", + "ip-addresses": [ + {"ip-address": "10.0.0.100", "ip-address-type": "ipv4"} + ], + }, + ] + } + else: + ifaces = mock_interfaces + + mock_client.api_call.return_value = ifaces + + with ( + 
patch("proxmox_mcp.tools.ssh_tools.get_client", return_value=mock_client), + patch("proxmox_mcp.tools.ssh_tools.get_ssh", return_value=mock_ssh), + ): + yield + + return _ctx() + + +# --------------------------------------------------------------------------- +# IP Auto-Discovery +# --------------------------------------------------------------------------- + + +class TestResolveVmIp: + async def test_discovers_ipv4_from_qemu_agent(self, mock_client): + from proxmox_mcp.tools.ssh_tools import _resolve_vm_ip + + mock_client.api_call.return_value = { + "result": [ + { + "name": "lo", + "ip-addresses": [ + {"ip-address": "127.0.0.1", "ip-address-type": "ipv4"} + ], + }, + { + "name": "eth0", + "ip-addresses": [ + {"ip-address": "10.0.0.50", "ip-address-type": "ipv4"}, + {"ip-address": "fe80::1", "ip-address-type": "ipv6"}, + ], + }, + ] + } + + ip = await _resolve_vm_ip(mock_client, 100, "pve1") + assert ip == "10.0.0.50" + + async def test_skips_loopback_and_link_local(self, mock_client): + from proxmox_mcp.tools.ssh_tools import _resolve_vm_ip + + mock_client.api_call.return_value = { + "result": [ + { + "name": "lo", + "ip-addresses": [ + {"ip-address": "127.0.0.1", "ip-address-type": "ipv4"} + ], + }, + { + "name": "eth0", + "ip-addresses": [ + {"ip-address": "169.254.1.1", "ip-address-type": "ipv4"}, + {"ip-address": "192.168.1.10", "ip-address-type": "ipv4"}, + ], + }, + ] + } + + ip = await _resolve_vm_ip(mock_client, 100, "pve1") + assert ip == "192.168.1.10" + + async def test_falls_back_to_lxc_interfaces(self, mock_client): + from proxmox_mcp.tools.ssh_tools import _resolve_vm_ip + + # QEMU agent fails + call_count = 0 + + async def side_effect(func): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Exception("QEMU agent not available") + return [{"name": "eth0", "inet": "172.16.0.5/24"}] + + mock_client.api_call = side_effect + + ip = await _resolve_vm_ip(mock_client, 200, "pve1") + assert ip == "172.16.0.5" + + async def 
test_raises_when_no_interfaces(self, mock_client): + from proxmox_mcp.ssh import SSHExecutionError + from proxmox_mcp.tools.ssh_tools import _resolve_vm_ip + + mock_client.api_call = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(SSHExecutionError, match="Cannot discover IP"): + await _resolve_vm_ip(mock_client, 999, "pve1") + + +# --------------------------------------------------------------------------- +# install_package +# --------------------------------------------------------------------------- + + +class TestInstallPackage: + async def test_install_success_apt(self, mock_client, mock_ssh): + from proxmox_mcp.tools.ssh_tools import install_package + + # First call: detect package manager + mock_ssh.execute_on_host.side_effect = [ + SSHResult(exit_code=0, stdout="found\n", stderr=""), # apt-get detected + SSHResult(exit_code=0, stdout="Reading package lists...\nDone\n", stderr=""), # install + ] + + with _patch_tools(mock_client, mock_ssh): + result = await install_package(vmid=100, packages=["nginx", "curl"]) + + assert result["status"] == "success" + assert result["packages"] == ["nginx", "curl"] + assert result["package_manager"] == "apt" + assert result["vmid"] == 100 + + async def test_install_empty_packages(self, mock_client, mock_ssh): + from proxmox_mcp.tools.ssh_tools import install_package + + with _patch_tools(mock_client, mock_ssh): + result = await install_package(vmid=100, packages=[]) + + assert result["status"] == "error" + assert "At least one package" in result["message"] + + async def test_install_invalid_package_name(self, mock_client, mock_ssh): + from proxmox_mcp.tools.ssh_tools import install_package + + with _patch_tools(mock_client, mock_ssh): + result = await install_package(vmid=100, packages=["valid", "bad;rm -rf /"]) + + assert result["status"] == "error" + assert "forbidden characters" in result["message"] + + async def test_remove_packages(self, mock_client, mock_ssh): + from proxmox_mcp.tools.ssh_tools 
class TestInstallPackage:
    """Tests for the install_package tool: package-manager detection,
    install/remove actions, target-IP and SSH overrides, and failures.

    The side_effect lists mirror the exact SSH call order of the
    implementation (detection probe(s) first, then the actual command),
    so their ordering is part of the contract being tested.
    """

    async def test_install_success_apt(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        # First call: detect package manager
        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=0, stdout="found\n", stderr=""),  # apt-get detected
            SSHResult(exit_code=0, stdout="Reading package lists...\nDone\n", stderr=""),  # install
        ]

        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=["nginx", "curl"])

        assert result["status"] == "success"
        assert result["packages"] == ["nginx", "curl"]
        assert result["package_manager"] == "apt"
        assert result["vmid"] == 100

    async def test_install_empty_packages(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        # Validation should fail before any SSH call is made.
        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=[])

        assert result["status"] == "error"
        assert "At least one package" in result["message"]

    async def test_install_invalid_package_name(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        # A shell-metacharacter payload must be rejected by the sanitizer.
        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=["valid", "bad;rm -rf /"])

        assert result["status"] == "error"
        assert "forbidden characters" in result["message"]

    async def test_remove_packages(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=0, stdout="found\n", stderr=""),  # apt-get detected
            SSHResult(exit_code=0, stdout="Removing nginx...\n", stderr=""),  # remove
        ]

        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=["nginx"], action="remove")

        assert result["status"] == "success"
        assert result["action"] == "remove"

    async def test_install_with_target_ip(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=0, stdout="found\n", stderr=""),
            SSHResult(exit_code=0, stdout="Done\n", stderr=""),
        ]

        # Explicit target_ip should bypass agent-based IP discovery.
        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(
                vmid=100, packages=["git"], target_ip="192.168.1.50"
            )

        assert result["status"] == "success"
        assert result["target_ip"] == "192.168.1.50"

    async def test_install_with_ssh_overrides(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=0, stdout="found\n", stderr=""),
            SSHResult(exit_code=0, stdout="Done\n", stderr=""),
        ]

        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(
                vmid=100,
                packages=["htop"],
                ssh_user="admin",
                ssh_port=2222,
            )

        assert result["status"] == "success"
        # Verify SSH overrides were passed
        call_kwargs = mock_ssh.execute_on_host.call_args_list[0]
        assert call_kwargs.kwargs.get("username") == "admin"
        assert call_kwargs.kwargs.get("port") == 2222

    async def test_install_detects_dnf(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        # Detection falls through apt-get -> dnf when apt-get is absent.
        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=1, stdout="", stderr=""),  # apt-get not found
            SSHResult(exit_code=0, stdout="found\n", stderr=""),  # dnf found
            SSHResult(exit_code=0, stdout="Complete!\n", stderr=""),  # install
        ]

        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=["httpd"])

        assert result["status"] == "success"
        assert result["package_manager"] == "dnf"

    async def test_install_failure(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import install_package

        mock_ssh.execute_on_host.side_effect = [
            SSHResult(exit_code=0, stdout="found\n", stderr=""),  # apt detected
            SSHResult(exit_code=100, stdout="", stderr="E: Unable to locate package foobar"),
        ]

        with _patch_tools(mock_client, mock_ssh):
            result = await install_package(vmid=100, packages=["foobar"])

        assert result["status"] == "error"
        assert result["exit_code"] == 100
class TestManageService:
    """Tests for the manage_service tool (systemd actions over SSH)."""

    async def test_start_service(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import manage_service

        mock_ssh.execute_on_host.return_value = SSHResult(exit_code=0, stdout="", stderr="")

        with _patch_tools(mock_client, mock_ssh):
            outcome = await manage_service(vmid=100, service="nginx", action="start")

        assert outcome["status"] == "success"
        assert outcome["action"] == "start"
        assert outcome["service"] == "nginx"

    async def test_service_status(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import manage_service

        # The status command emits is-active, is-enabled, then the unit header.
        status_lines = "active\nenabled\nā— nginx.service - A high performance web server\n"
        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=0, stdout=status_lines, stderr=""
        )

        with _patch_tools(mock_client, mock_ssh):
            outcome = await manage_service(vmid=100, service="nginx", action="status")

        assert outcome["status"] == "success"
        assert outcome["is_active"] == "active"
        assert outcome["is_enabled"] == "enabled"

    async def test_invalid_action(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import manage_service

        # "destroy" is outside the action allow-list.
        with _patch_tools(mock_client, mock_ssh):
            outcome = await manage_service(vmid=100, service="nginx", action="destroy")

        assert outcome["status"] == "error"
        assert "invalid" in outcome["message"].lower()

    async def test_invalid_service_name(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import manage_service

        # Shell metacharacters in the unit name must be rejected.
        with _patch_tools(mock_client, mock_ssh):
            outcome = await manage_service(vmid=100, service="bad;service", action="start")

        assert outcome["status"] == "error"
        assert "forbidden characters" in outcome["message"]

    async def test_restart_failure(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import manage_service

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=1, stdout="", stderr="Failed to restart nginx.service: Unit not found."
        )

        with _patch_tools(mock_client, mock_ssh):
            outcome = await manage_service(vmid=100, service="nginx", action="restart")

        assert outcome["status"] == "error"
        assert outcome["exit_code"] == 1
class TestTransferFile:
    """Tests for the transfer_file tool (content push with perms/owner)."""

    async def test_transfer_success(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import transfer_file

        mock_ssh.execute_on_host.return_value = SSHResult(exit_code=0, stdout="", stderr="")

        payload = "server { listen 80; }"
        with _patch_tools(mock_client, mock_ssh):
            res = await transfer_file(
                vmid=100,
                content=payload,
                destination="/etc/nginx/sites-available/default",
            )

        assert res["status"] == "success"
        assert res["destination"] == "/etc/nginx/sites-available/default"
        # Default mode is 0644; size reflects the encoded byte length.
        assert res["permissions"] == "0644"
        assert res["size_bytes"] == len(payload.encode())

    async def test_transfer_with_owner(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import transfer_file

        mock_ssh.execute_on_host.return_value = SSHResult(exit_code=0, stdout="", stderr="")

        with _patch_tools(mock_client, mock_ssh):
            res = await transfer_file(
                vmid=100,
                content="hello",
                destination="/var/www/index.html",
                permissions="0755",
                owner="www-data:www-data",
            )

        assert res["status"] == "success"
        assert res["owner"] == "www-data:www-data"
        # The remote command must include the chown step.
        remote_cmd = mock_ssh.execute_on_host.call_args[0][1]
        assert "chown www-data:www-data" in remote_cmd

    async def test_transfer_invalid_path(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import transfer_file

        # Relative destinations are rejected before any SSH call.
        with _patch_tools(mock_client, mock_ssh):
            res = await transfer_file(vmid=100, content="data", destination="relative/path")

        assert res["status"] == "error"
        assert "invalid" in res["message"].lower()

    async def test_transfer_path_traversal(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import transfer_file

        with _patch_tools(mock_client, mock_ssh):
            res = await transfer_file(
                vmid=100, content="data", destination="/etc/../../../tmp/evil"
            )

        assert res["status"] == "error"
        assert "traversal" in res["message"].lower()

    async def test_transfer_invalid_permissions(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import transfer_file

        # 9999 is not a valid octal mode.
        with _patch_tools(mock_client, mock_ssh):
            res = await transfer_file(
                vmid=100, content="data", destination="/tmp/test", permissions="9999"
            )

        assert res["status"] == "error"
        assert "invalid" in res["message"].lower()
class TestExecuteScript:
    """Tests for the execute_script tool: interpreter selection/validation,
    empty-script rejection, failure propagation, and timeout clamping.
    """

    async def test_script_success(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=0, stdout="Hello World\n", stderr=""
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(
                vmid=100,
                script='echo "Hello World"',
            )

        assert result["status"] == "success"
        assert result["exit_code"] == 0
        assert "Hello World" in result["stdout"]
        # bash is the default interpreter when none is given.
        assert result["interpreter"] == "bash"

    async def test_script_python(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=0, stdout="42\n", stderr=""
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(
                vmid=100,
                script="print(6 * 7)",
                interpreter="python3",
            )

        assert result["status"] == "success"
        assert result["interpreter"] == "python3"

    async def test_script_invalid_interpreter(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        # ruby is outside VALID_SCRIPT_INTERPRETERS.
        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(
                vmid=100,
                script="echo hi",
                interpreter="ruby",
            )

        assert result["status"] == "error"
        assert "not allowed" in result["message"]

    async def test_script_empty(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        # Whitespace-only scripts are treated as empty.
        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(vmid=100, script=" ")

        assert result["status"] == "error"
        assert "empty" in result["message"].lower()

    async def test_script_failure(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=1, stdout="", stderr="command not found"
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(vmid=100, script="nonexistent_command")

        assert result["status"] == "error"
        assert result["exit_code"] == 1

    async def test_script_timeout_clamped(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import execute_script

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=0, stdout="ok\n", stderr=""
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await execute_script(vmid=100, script="echo ok", timeout=999)

        assert result["status"] == "success"
        # Verify timeout was clamped to 120
        # (checks the kwarg first, falling back to the third positional arg,
        # so the assertion holds whichever way the implementation passes it).
        call_kwargs = mock_ssh.execute_on_host.call_args
        assert call_kwargs.kwargs.get("timeout", call_kwargs[0][2] if len(call_kwargs[0]) > 2 else None) == 120
class TestGetSystemInfo:
    """Tests for get_system_info, which runs one combined probe command and
    parses its ---SECTION--- delimited output into a structured result.
    """

    async def test_system_info_success(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import get_system_info

        # Fixture mirrors the delimiter protocol the parser expects; keep the
        # section markers and field order exactly as the probe emits them.
        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=0,
            stdout=(
                "---HOSTNAME---\n"
                "web-server-01\n"
                "---OS---\n"
                'PRETTY_NAME="Debian GNU/Linux 12 (bookworm)"\n'
                'NAME="Debian GNU/Linux"\n'
                'ID=debian\n'
                "---KERNEL---\n"
                "6.1.0-18-amd64\n"
                "---ARCH---\n"
                "x86_64\n"
                "---UPTIME---\n"
                "up 5 days, 3 hours, 22 minutes\n"
                "---CPU---\n"
                "4\n"
                "---MEMORY---\n"
                "Mem: 8349024256 2147483648 4194304000 134217728 2007236608 6000000000\n"
                "---DISK---\n"
                "/dev/sda1 50000000000 15000000000 33000000000 32% /\n"
                "---LOAD---\n"
                "0.15 0.10 0.05 1/234 5678\n"
                "---END---\n"
            ),
            stderr="",
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await get_system_info(vmid=100)

        assert result["status"] == "success"
        assert result["hostname"] == "web-server-01"
        assert "Debian" in result["os"]
        assert result["os_id"] == "debian"
        assert result["kernel"] == "6.1.0-18-amd64"
        assert result["architecture"] == "x86_64"
        assert result["cpu_count"] == 4
        assert result["memory"]["total_bytes"] == 8349024256
        assert result["root_disk"]["use_percent"] == "32%"
        assert result["load_average"]["1min"] == "0.15"

    async def test_system_info_failure(self, mock_client, mock_ssh):
        from proxmox_mcp.tools.ssh_tools import get_system_info

        mock_ssh.execute_on_host.return_value = SSHResult(
            exit_code=255, stdout="", stderr="Connection refused"
        )

        with _patch_tools(mock_client, mock_ssh):
            result = await get_system_info(vmid=100)

        assert result["status"] == "error"
        # NOTE(review): other tools in this suite expose the text under
        # "message"; get_system_info apparently uses "error" — confirm this
        # key inconsistency against the implementation.
        assert "Connection refused" in result["error"]
class TestSSHSanitizers:
    """Direct unit tests for the SSH-related sanitizer helpers."""

    def test_valid_package_names(self):
        from proxmox_mcp.utils.sanitizers import validate_package_name

        # None of these may raise.
        for pkg in ("nginx", "python3.11", "libssl-dev", "gcc-12", "dotnet-sdk-8.0"):
            validate_package_name(pkg)

    def test_invalid_package_names(self):
        from proxmox_mcp.utils.errors import InvalidParameterError
        from proxmox_mcp.utils.sanitizers import validate_package_name

        for pkg in ("bad;pkg", "$(evil)", "pkg name", ""):
            with pytest.raises(InvalidParameterError):
                validate_package_name(pkg)

    def test_valid_service_names(self):
        from proxmox_mcp.utils.sanitizers import validate_service_name

        for unit in ("nginx", "docker.service", "getty@tty1", "ssh-agent"):
            validate_service_name(unit)

    def test_invalid_service_names(self):
        from proxmox_mcp.utils.errors import InvalidParameterError
        from proxmox_mcp.utils.sanitizers import validate_service_name

        for unit in ("bad;svc", "$(whoami)", ""):
            with pytest.raises(InvalidParameterError):
                validate_service_name(unit)

    def test_valid_service_actions(self):
        from proxmox_mcp.utils.sanitizers import validate_service_action

        for verb in ("start", "stop", "restart", "reload", "enable", "disable", "status"):
            validate_service_action(verb)

    def test_invalid_service_actions(self):
        from proxmox_mcp.utils.errors import InvalidParameterError
        from proxmox_mcp.utils.sanitizers import validate_service_action

        for verb in ("destroy", "kill", "exec"):
            with pytest.raises(InvalidParameterError):
                validate_service_action(verb)

    def test_valid_remote_paths(self):
        from proxmox_mcp.utils.sanitizers import validate_remote_file_path

        for target in ("/etc/nginx/nginx.conf", "/tmp/test", "/var/www/html/index.html"):
            validate_remote_file_path(target)

    def test_invalid_remote_paths(self):
        from proxmox_mcp.utils.errors import InvalidParameterError
        from proxmox_mcp.utils.sanitizers import validate_remote_file_path

        # relative path, traversal, and a forbidden character (space)
        for target in ("relative", "/etc/../secret", "/path with spaces"):
            with pytest.raises(InvalidParameterError):
                validate_remote_file_path(target)

    def test_valid_interpreters(self):
        from proxmox_mcp.utils.sanitizers import validate_script_interpreter

        for runner in ("bash", "sh", "python3", "python", "perl"):
            validate_script_interpreter(runner)

    def test_invalid_interpreters(self):
        from proxmox_mcp.utils.errors import InvalidParameterError
        from proxmox_mcp.utils.sanitizers import validate_script_interpreter

        for runner in ("ruby", "node", "php"):
            with pytest.raises(InvalidParameterError):
                validate_script_interpreter(runner)