diff --git a/docs/api/authentication.md b/docs/api/authentication.md new file mode 100644 index 000000000..35b580536 --- /dev/null +++ b/docs/api/authentication.md @@ -0,0 +1,618 @@ +# Zaparoo Core Secure API Authentication + +This document provides a complete guide for developers implementing clients that connect to Zaparoo Core using the secure authentication layer. + +## Table of Contents + +- [Overview](#overview) +- [Security Model](#security-model) +- [Client Pairing](#client-pairing) +- [HTTP API Usage](#http-api-usage) +- [WebSocket API Usage](#websocket-api-usage) +- [Example Implementations](#example-implementations) +- [Error Handling](#error-handling) +- [Security Best Practices](#security-best-practices) + +## Overview + +Zaparoo Core implements a secure authentication system for remote client connections while maintaining backward compatibility for localhost connections. The security model uses: + +- **AES-256-GCM** encryption for all remote API communications +- **Argon2id + HKDF** key derivation for client pairing +- **Sequence numbers + nonces** for replay attack protection +- **Per-client state management** for concurrent access safety + +### Connection Types + +| Connection Type | Authentication Required | Encryption Required | +| -------------------------- | ----------------------- | ------------------- | +| Localhost (127.0.0.1, ::1) | ❌ No | ❌ No | +| Remote (all other IPs) | ✅ Yes | ✅ Yes | + +## Security Model + +### Encryption Format + +All remote API requests must use the following encrypted format: + +```json +{ + "encrypted": "base64-encoded-ciphertext", + "iv": "base64-encoded-initialization-vector", + "authToken": "client-auth-token" +} +``` + +### Decrypted Payload Format + +The decrypted payload contains the standard JSON-RPC request plus security fields: + +```json +{ + "jsonrpc": "2.0", + "method": "system.version", + "id": 1, + "params": {}, + "seq": 123, + "nonce": "unique-request-nonce" +} +``` + +## Client Pairing + +Before making authenticated requests, clients must complete a pairing process to obtain shared encryption keys. + +### Step 1: Initiate Pairing + +**Request:** + +```http +POST /api/pair/initiate +Content-Type: application/json + +{ + "clientName": "MyApp_v1.0" +} +``` + +**Response:** + +```json +{ + "pairingToken": "550e8400-e29b-41d4-a716-446655440000", + "expiresIn": 300 +} +``` + +### Step 2: Complete Pairing + +The client must obtain a verification code through an out-of-band method (QR code, manual entry, etc.) and complete pairing. + +**Request:** + +```http +POST /api/pair/complete +Content-Type: application/json + +{ + "pairingToken": "550e8400-e29b-41d4-a716-446655440000", + "verifier": "user-provided-verification-code", + "clientName": "MyApp_v1.0" +} +``` + +**Response:** + +```json +{ + "clientId": "client-uuid-here", + "authToken": "auth-token-uuid-here", + "sharedSecret": "64-character-hex-encoded-key" +} +``` + +### Client Name Requirements + +- Length: 1-100 characters +- Characters: letters, numbers, underscore, dash only (`^[a-zA-Z0-9_-]+$`) +- Must be unique per client instance + +## HTTP API Usage + +### Making Authenticated Requests + +1. **Prepare the JSON-RPC payload** with sequence number and nonce +2. **Encrypt the payload** using AES-256-GCM +3. **Send the encrypted request** to the API endpoint + +### Example: Get System Version + +```javascript +// 1. Prepare JSON-RPC payload +const payload = { + jsonrpc: "2.0", + method: "system.version", + id: 1, + seq: ++sequenceNumber, + nonce: generateNonce(), +}; + +// 2. Encrypt payload +const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret); + +// 3. Send request +const response = await fetch("http://zaparoo-host:7497/api/v0.1", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: authToken, + }), +}); + +// 4. Decrypt response +const encryptedResponse = await response.json(); +const decryptedResponse = await decryptPayload( + encryptedResponse.encrypted, + encryptedResponse.iv, + sharedSecret +); +``` + +## WebSocket API Usage + +WebSocket connections require a two-step authentication process: + +### Step 1: Connect and Authenticate + +```javascript +const ws = new WebSocket("ws://zaparoo-host:7497/api/v0.1"); + +ws.onopen = () => { + // Send authentication message + ws.send( + JSON.stringify({ + authToken: authToken, + }) + ); +}; + +ws.onmessage = (event) => { + const message = JSON.parse(event.data); + + if (message.authenticated) { + console.log("WebSocket authenticated successfully"); + // Now ready to send encrypted requests + } else { + // Handle encrypted response + handleEncryptedMessage(message); + } +}; +``` + +### Step 2: Send Encrypted Requests + +```javascript +async function sendRequest(method, params = {}) { + const payload = { + jsonrpc: "2.0", + method: method, + id: generateId(), + params: params, + seq: ++sequenceNumber, + nonce: generateNonce(), + }; + + const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret); + + ws.send( + JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: authToken, + }) + ); +} + +// Example usage +await sendRequest("system.version"); +await sendRequest("search.games", { query: "mario" }); +``` + +### Handling Encrypted Responses + +```javascript +async function handleEncryptedMessage(encryptedMessage) { + if (encryptedMessage.encrypted && encryptedMessage.iv) { + const decrypted = await decryptPayload( + encryptedMessage.encrypted, + encryptedMessage.iv, + sharedSecret + ); + + const response = JSON.parse(decrypted); + console.log("Received response:", response); + } +} +``` + +## Example Implementations + +### JavaScript/Node.js Client + +```javascript +import crypto from "crypto"; + +class ZaparooSecureClient { + constructor(host, port = 7497) { + this.baseUrl = `http://${host}:${port}`; + this.wsUrl = `ws://${host}:${port}`; + this.sequenceNumber = 0; + this.authToken = null; + this.sharedSecret = null; + } + + // Pairing process + async pair(clientName, verifier) { + // Step 1: Initiate pairing + const initResponse = await fetch(`${this.baseUrl}/api/pair/initiate`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ clientName }), + }); + + const { pairingToken } = await initResponse.json(); + + // Step 2: Complete pairing (verifier obtained out-of-band) + const completeResponse = await fetch(`${this.baseUrl}/api/pair/complete`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + pairingToken, + verifier, + clientName, + }), + }); + + const result = await completeResponse.json(); + this.authToken = result.authToken; + this.sharedSecret = Buffer.from(result.sharedSecret, "hex"); + + return result; + } + + // Encryption utilities + async encryptPayload(data) { + const iv = crypto.randomBytes(12); + const cipher = crypto.createCipherGCM("aes-256-gcm", this.sharedSecret); + cipher.setAAD(Buffer.alloc(0)); + + const encrypted = Buffer.concat([ + cipher.update(data, "utf8"), + cipher.final(), + ]); + + const authTag = cipher.getAuthTag(); + const ciphertext = Buffer.concat([encrypted, authTag]); + + return { + ciphertext: ciphertext.toString("base64"), + iv: iv.toString("base64"), + }; + } + + async decryptPayload(encryptedB64, ivB64) { + const encrypted = Buffer.from(encryptedB64, "base64"); + const iv = Buffer.from(ivB64, "base64"); + + const authTag = encrypted.slice(-16); + const ciphertext = encrypted.slice(0, -16); + + const decipher = crypto.createDecipherGCM("aes-256-gcm", this.sharedSecret); + decipher.setAuthTag(authTag); + decipher.setAAD(Buffer.alloc(0)); + + const decrypted = Buffer.concat([ + decipher.update(ciphertext), + decipher.final(), + ]); + + return decrypted.toString("utf8"); + } + + generateNonce() { + return crypto.randomBytes(16).toString("hex"); + } + + // HTTP API request + async request(method, params = {}) { + const payload = { + jsonrpc: "2.0", + method, + id: Date.now(), + params, + seq: ++this.sequenceNumber, + nonce: this.generateNonce(), + }; + + const encrypted = await this.encryptPayload(JSON.stringify(payload)); + + const response = await fetch(`${this.baseUrl}/api/v0.1`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: this.authToken, + }), + }); + + const encryptedResponse = await response.json(); + const decryptedData = await this.decryptPayload( + encryptedResponse.encrypted, + encryptedResponse.iv + ); + + return JSON.parse(decryptedData); + } +} + +// Usage example +const client = new ZaparooSecureClient("192.168.1.100"); + +// Pair with server (verifier obtained from QR code/user input) +await client.pair("MyApp_v1.0", "verification-code-from-qr"); + +// Make API requests +const version = await client.request("system.version"); +const games = await client.request("search.games", { query: "zelda" }); +``` + +### Python Client + +```python +import asyncio +import json +import secrets +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +import aiohttp +import websockets + +class ZaparooSecureClient: + def __init__(self, host, port=7497): + self.base_url = f"http://{host}:{port}" + self.ws_url = f"ws://{host}:{port}" + self.sequence_number = 0 + self.auth_token = None + self.shared_secret = None + + async def pair(self, client_name, verifier): + async with aiohttp.ClientSession() as session: + # Initiate pairing + async with session.post( + f"{self.base_url}/api/pair/initiate", + json={"clientName": client_name} + ) as response: + result = await response.json() + pairing_token = result["pairingToken"] + + # Complete pairing + async with session.post( + f"{self.base_url}/api/pair/complete", + json={ + "pairingToken": pairing_token, + "verifier": verifier, + "clientName": client_name + } + ) as response: + result = await response.json() + self.auth_token = result["authToken"] + self.shared_secret = bytes.fromhex(result["sharedSecret"]) + return result + + def encrypt_payload(self, data): + aesgcm = AESGCM(self.shared_secret) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, data.encode(), None) + + return { + "ciphertext": ciphertext.hex(), + "iv": nonce.hex() + } + + def decrypt_payload(self, encrypted_hex, iv_hex): + aesgcm = AESGCM(self.shared_secret) + ciphertext = bytes.fromhex(encrypted_hex) + nonce = bytes.fromhex(iv_hex) + + plaintext = aesgcm.decrypt(nonce, ciphertext, None) + return plaintext.decode() + + def generate_nonce(self): + return secrets.token_hex(16) + + async def request(self, method, params=None): + if params is None: + params = {} + + payload = { + "jsonrpc": "2.0", + "method": method, + "id": 1, + "params": params, + "seq": self.sequence_number + 1, + "nonce": self.generate_nonce() + } + self.sequence_number += 1 + + encrypted = self.encrypt_payload(json.dumps(payload)) + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/api/v0.1", + json={ + "encrypted": encrypted["ciphertext"], + "iv": encrypted["iv"], + "authToken": self.auth_token + } + ) as response: + encrypted_response = await response.json() + decrypted = self.decrypt_payload( + encrypted_response["encrypted"], + encrypted_response["iv"] + ) + return json.loads(decrypted) + +# Usage +client = ZaparooSecureClient("192.168.1.100") +await client.pair("MyPythonApp", "verification-code") +result = await client.request("system.version") +``` + +## Error Handling + +### Common Error Responses + +| Error | HTTP Status | Description | +| ---------------------- | ----------- | -------------------------------- | +| Invalid auth token | 401 | Token not found or expired | +| Invalid request format | 400 | Malformed JSON or missing fields | +| Decryption failed | 400 | Invalid encryption or wrong key | +| Invalid sequence | 400 | Replay attack detected | +| Method not allowed | 405 | Wrong HTTP method | + +### Example Error Response + +```json +{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32600, + "message": "Invalid Request" + } +} +``` + +### Client-Side Error Handling + +```javascript +try { + const result = await client.request("system.version"); + console.log(result); +} catch (error) { + if (error.response?.status === 401) { + console.log("Authentication failed - re-pair required"); + await client.pair(clientName, newVerifier); + } else if (error.response?.status === 400) { + console.log("Request error:", error.response.data); + } else { + console.log("Network error:", error.message); + } +} +``` + +## Security Best Practices + +### For Client Developers + +1. **Secure Key Storage**: Store shared secrets securely (keychain, encrypted storage) +2. **Sequence Management**: Persist sequence numbers across app restarts +3. **Nonce Uniqueness**: Generate cryptographically random nonces +4. **Error Handling**: Don't expose sensitive information in logs +5. **Network Security**: Use TLS for additional transport security +6. **Token Rotation**: Implement re-pairing for long-running applications + +### Example Secure Storage (Node.js) + +```javascript +import keytar from "keytar"; + +class SecureStorage { + static async storeCredentials( + clientId, + authToken, + sharedSecret, + sequenceNumber + ) { + await keytar.setPassword("zaparoo", `${clientId}-auth`, authToken); + await keytar.setPassword("zaparoo", `${clientId}-secret`, sharedSecret); + await keytar.setPassword( + "zaparoo", + `${clientId}-seq`, + sequenceNumber.toString() + ); + } + + static async loadCredentials(clientId) { + const authToken = await keytar.getPassword("zaparoo", `${clientId}-auth`); + const sharedSecret = await keytar.getPassword( + "zaparoo", + `${clientId}-secret` + ); + const sequenceNumber = parseInt( + (await keytar.getPassword("zaparoo", `${clientId}-seq`)) || "0" + ); + + return { authToken, sharedSecret, sequenceNumber }; + } +} +``` + +### Sequence Number Management + +```javascript +class SequenceManager { + constructor(clientId) { + this.clientId = clientId; + this.sequenceNumber = this.loadSequenceNumber(); + } + + getNext() { + this.sequenceNumber++; + this.saveSequenceNumber(); + return this.sequenceNumber; + } + + loadSequenceNumber() { + // Load from persistent storage + return parseInt( + localStorage.getItem(`zaparoo-seq-${this.clientId}`) || "0" + ); + } + + saveSequenceNumber() { + localStorage.setItem( + `zaparoo-seq-${this.clientId}`, + this.sequenceNumber.toString() + ); + } +} +``` + +## API Endpoints Reference + +### Available Endpoints + +| Endpoint | Authentication | Description | +| ------------------------- | ----------------- | ------------------------ | +| `POST /api/pair/initiate` | None | Start pairing process | +| `POST /api/pair/complete` | None | Complete pairing process | +| `POST /api/v0.1` | Required (remote) | JSON-RPC API | +| `WS /api/v0.1` | Required (remote) | WebSocket API | + +### Supported Methods + +Once authenticated, all standard Zaparoo Core JSON-RPC methods are available: + +- `system.version` - Get system version +- `system.heartbeat` - Health check +- `search.games` - Search game database +- `launch.game` - Launch a game +- `media.scan` - Trigger media scan +- And more... (see main API documentation) + +--- + +**Note**: This secure authentication layer is only required for remote connections. Localhost connections (127.0.0.1, ::1) continue to work without authentication for development and local tool access. diff --git a/go.mod b/go.mod index 9a19ffedb..4d1dedcb7 100755 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( go.bug.st/serial v1.6.4 go.etcd.io/bbolt v1.4.0 golang.design/x/clipboard v0.7.0 + golang.org/x/crypto v0.38.0 golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 golang.org/x/sys v0.35.0 golang.org/x/text v0.26.0 @@ -54,6 +55,8 @@ require ( github.com/geoffgarside/ber v1.1.0 // indirect github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/gorilla/csrf v1.7.3 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/jcmturner/aescts/v2 v2.0.0 // indirect github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect @@ -73,8 +76,8 @@ require ( github.com/sethvargo/go-retry v0.3.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af // indirect + github.com/unrolled/secure v1.17.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp/shiny v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/image v0.20.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect diff --git a/go.sum b/go.sum index c2abad4e7..bbe7e1339 100644 --- a/go.sum +++ b/go.sum @@ -59,8 +59,12 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -143,6 +147,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af h1:6yITBqGTE2lEeTPG04SN9W+iWHCRyHqlVYILiSXziwk= github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af/go.mod h1:4F09kP5F+am0jAwlQLddpoMDM+iewkxxt6nxUQ5nq5o= +github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= +github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.bug.st/serial v1.6.4 h1:7FmqNPgVp3pu2Jz5PoPtbZ9jJO5gnEnZIvnI1lzve8A= go.bug.st/serial v1.6.4/go.mod h1:nofMJxTeNVny/m6+KaafC6vJGj3miwQZ6vW4BZUGJPI= diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go new file mode 100644 index 000000000..bcafa3b39 --- /dev/null +++ b/pkg/api/middleware/auth.go @@ -0,0 +1,399 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/rs/zerolog/log" +) + +const ( + MutexCleanupInterval = 10 * time.Minute // Cleanup unused mutexes every 10 minutes + MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes +) + +type clientKey string + +// ClientMutexManager handles per-client locking to prevent race conditions +// in authentication state updates +type ClientMutexManager struct { + mutexes sync.Map // map[string]*clientMutex +} + +// ClientMutex represents a per-client mutex for thread-safe operations +type ClientMutex struct { + lastUsed time.Time + clientID string + mu sync.Mutex +} + +// Legacy alias for backward compatibility +type clientMutex = ClientMutex + +var globalClientMutexManager = &ClientMutexManager{} + +type EncryptedRequest struct { + Encrypted string `json:"encrypted"` + IV string `json:"iv"` + AuthToken string `json:"authToken"` +} + +type DecryptedPayload struct { + ID any `json:"id,omitempty"` + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Nonce string `json:"nonce"` + Params json.RawMessage `json:"params,omitempty"` + Seq uint64 `json:"seq"` +} + +func isLocalhost(remoteAddr string) bool { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + host = remoteAddr + } + + ip := net.ParseIP(host) + if ip == nil { + return host == "localhost" + } + + return ip.IsLoopback() +} + +func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip authentication for localhost connections + if isLocalhost(r.RemoteAddr) { + log.Debug().Str("remote_addr", r.RemoteAddr).Msg("localhost connection - skipping auth") + next.ServeHTTP(w, r) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + log.Error().Err(err).Msg("failed to read request body") + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + // Try to parse as encrypted request + var encReq EncryptedRequest + if parseErr := json.Unmarshal(body, &encReq); parseErr != nil { + log.Error().Err(parseErr).Msg("invalid encrypted request format") + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + + // First, validate auth token and get client for initial validation + initialClient, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) + if err != nil { + tokenStr := "empty" + if len(encReq.AuthToken) >= 8 { + tokenStr = encReq.AuthToken[:8] + "..." + } else if encReq.AuthToken != "" { + tokenStr = encReq.AuthToken + } + log.Warn().Err(err). + Str("token", tokenStr). + Str("remote_addr", r.RemoteAddr). + Str("user_agent", r.Header.Get("User-Agent")). + Msg("SECURITY: invalid auth token - potential attack") + http.Error(w, "Invalid auth token", http.StatusUnauthorized) + return + } + + // Acquire client lock BEFORE any decryption or validation + // to prevent race conditions between concurrent requests + unlockClient := LockClient(initialClient.ClientID) + defer unlockClient() + + // Re-fetch client state under lock to get latest sequence/nonce state + client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) + if err != nil { + log.Error().Err(err).Msg("failed to re-fetch client under lock") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Decrypt payload with locked client data + decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, client.SharedSecret) + if err != nil { + log.Error().Err(err).Msg("failed to decrypt payload") + http.Error(w, "Decryption failed", http.StatusBadRequest) + return + } + + // Parse decrypted payload + var payload DecryptedPayload + if parseErr := json.Unmarshal(decryptedPayload, &payload); parseErr != nil { + log.Error().Err(parseErr).Msg("invalid decrypted payload format") + http.Error(w, "Invalid payload format", http.StatusBadRequest) + return + } + + // Create replay protector from client state + replayProtector := NewReplayProtector(client) + + // Validate sequence number and nonce with locked client state + if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) { + log.Warn(). + Str("client_id", client.ClientID). + Uint64("seq", payload.Seq). + Str("nonce", payload.Nonce). + Str("remote_addr", r.RemoteAddr). + Str("user_agent", r.Header.Get("User-Agent")). + Msg("SECURITY: replay attack detected") + http.Error(w, "Invalid sequence or replay detected", http.StatusBadRequest) + return + } + + // Update replay protector state + replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce) + + // Get updated state for database storage + currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase() + client.CurrentSeq = currentSeq + client.SeqWindow = seqWindow + client.NonceCache = nonceCache + + // Save to database (still under lock) + if updateErr := db.UserDB.UpdateClientSequence( + client.ClientID, client.CurrentSeq, client.SeqWindow, client.NonceCache, + ); updateErr != nil { + log.Error().Err(updateErr).Msg("failed to update client state") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Replace request body with decrypted JSON-RPC payload + originalPayload := map[string]any{ + "jsonrpc": payload.JSONRPC, + "method": payload.Method, + "id": payload.ID, + } + if payload.Params != nil { + originalPayload["params"] = payload.Params + } + + newBody, err := json.Marshal(originalPayload) + if err != nil { + log.Error().Err(err).Msg("failed to marshal decrypted payload") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Create new request with decrypted body + r.Body = io.NopCloser(bytes.NewReader(newBody)) + r.ContentLength = int64(len(newBody)) + + // Store client in context for potential use by handlers + ctx := context.WithValue(r.Context(), clientKey("client"), client) + r = r.WithContext(ctx) + + log.Info(). + Str("client_id", client.ClientID). + Str("method", payload.Method). + Uint64("seq", payload.Seq). + Str("remote_addr", r.RemoteAddr). + Msg("SECURITY: authenticated request processed") + + next.ServeHTTP(w, r) + }) + } +} + +func DecryptPayload(encryptedB64, ivB64 string, key []byte) ([]byte, error) { + // Decode base64 + encrypted, err := base64.StdEncoding.DecodeString(encryptedB64) + if err != nil { + return nil, fmt.Errorf("invalid encrypted data: %w", err) + } + + iv, err := base64.StdEncoding.DecodeString(ivB64) + if err != nil { + return nil, fmt.Errorf("invalid IV: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Decrypt + plaintext, err := gcm.Open(nil, iv, encrypted, nil) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + return plaintext, nil +} + +func EncryptPayload(data, key []byte) (encrypted, iv string, err error) { + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return "", "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random IV + ivBytes := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, ivBytes); err != nil { + return "", "", fmt.Errorf("failed to generate IV: %w", err) + } + + // Encrypt + ciphertext := gcm.Seal(nil, ivBytes, data, nil) + + // Return base64 encoded values + return base64.StdEncoding.EncodeToString(ciphertext), + base64.StdEncoding.EncodeToString(ivBytes), nil +} + +// getClientMutex retrieves or creates a mutex for the given client ID +func (cm *ClientMutexManager) getClientMutex(clientID string) *ClientMutex { + // Try to load existing mutex + if value, exists := cm.mutexes.Load(clientID); exists { + mutex, ok := value.(*clientMutex) + if !ok { + log.Error().Str("client_id", clientID).Msg("invalid mutex type in cache") + return nil + } + mutex.lastUsed = time.Now() + return mutex + } + + // Create new mutex + newMutex := &ClientMutex{ + lastUsed: time.Now(), + clientID: clientID, + } + + // Store and return the mutex (LoadOrStore handles race conditions) + actual, _ := cm.mutexes.LoadOrStore(clientID, newMutex) + actualMutex, ok := actual.(*clientMutex) + if !ok { + log.Error().Str("client_id", clientID).Msg("invalid mutex type after LoadOrStore") + return nil + } + actualMutex.lastUsed = time.Now() + return actualMutex +} + +// lockClient acquires a lock for the specified client, preventing race conditions +// in authentication state updates. The returned function must be called to release the lock. +func (cm *ClientMutexManager) lockClient(clientID string) func() { + mutex := cm.getClientMutex(clientID) + mutex.mu.Lock() + + return func() { + mutex.mu.Unlock() + } +} + +// cleanup removes unused mutexes to prevent memory leaks +func (cm *ClientMutexManager) cleanup() { + now := time.Now() + cm.mutexes.Range(func(key, value any) bool { + mutex, ok := value.(*clientMutex) + if !ok { + log.Error().Interface("key", key).Msg("invalid mutex type in cleanup") + return true + } + if now.Sub(mutex.lastUsed) > MutexMaxIdle { + cm.mutexes.Delete(key) + log.Debug().Str("client_id", mutex.clientID).Msg("cleaned up unused client mutex") + } + return true + }) +} + +// StartCleanupRoutine starts a background goroutine to periodically clean up unused mutexes +func (cm *ClientMutexManager) StartCleanupRoutine(ctx context.Context) { + go func() { + ticker := time.NewTicker(MutexCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Debug().Msg("client mutex cleanup routine stopped") + return + case <-ticker.C: + cm.cleanup() + } + } + }() +} + +// GetClientMutex is a convenience function to get a mutex for a client +func GetClientMutex(clientID string) *ClientMutex { + return globalClientMutexManager.getClientMutex(clientID) +} + +// LockClient is a convenience function to lock a client +func LockClient(clientID string) func() { + return globalClientMutexManager.lockClient(clientID) +} + +// StartGlobalMutexCleanup starts the global mutex cleanup routine +func StartGlobalMutexCleanup(ctx context.Context) { + globalClientMutexManager.StartCleanupRoutine(ctx) +} + +func IsAuthenticatedConnection(r *http.Request) bool { + return !isLocalhost(r.RemoteAddr) +} + +func GetClientFromContext(ctx context.Context) *database.Client { + if client, ok := ctx.Value(clientKey("client")).(*database.Client); ok { + return client + } + return nil +} diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go new file mode 100644 index 000000000..96325d943 --- /dev/null +++ b/pkg/api/middleware/auth_test.go @@ -0,0 +1,638 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation +func TestAuthMiddleware_LocalhostBypass(t *testing.T) { + t.Parallel() + // Setup + userDB := helpers.NewMockUserDBI() + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + expectPass bool + }{ + {"localhost with port", "127.0.0.1:12345", true}, + {"localhost without port", "127.0.0.1", true}, + {"localhost name with port", "localhost:8080", true}, + {"localhost name without port", "localhost", true}, + {"IPv6 loopback", "[::1]:8080", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "localhost should bypass auth") + assert.Equal(t, "success", w.Body.String()) + }) + } +} + +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation +func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { + t.Parallel() + // Setup + userDB := helpers.NewMockUserDBI() + // Mock the GetClientByAuthToken call to return an error for empty token + // The middleware calls this once per test case (we have 2 test cases) + userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Times(2) + + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called for unauthenticated remote requests") + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + }{ + {"private network IP", "192.168.1.100:5000"}, + {"public IP", "8.8.8.8:80"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Send regular JSON without proper auth fields - should fail at auth token lookup + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Should fail because auth token is empty/invalid + assert.Equal(t, http.StatusUnauthorized, w.Code, "remote should fail auth") + assert.Contains(t, w.Body.String(), "Invalid auth token") + }) + } + + userDB.AssertExpectations(t) +} + +func TestAuthMiddleware_EncryptedRequest(t *testing.T) { + t.Parallel() + // Create a mock device with known shared secret + testSecret := []byte("test-secret-key-32-bytes-long-ok") + testDevice := &database.Client{ + ClientID: "test-device-id", + ClientName: "Test Device", + AuthTokenHash: "test-token-hash", + SharedSecret: testSecret, + CurrentSeq: 0, + SeqWindow: make([]byte, 8), + NonceCache: []string{}, + CreatedAt: time.Now(), + LastSeen: time.Now(), + } + + userDB := helpers.NewMockUserDBI() + userDB.On("GetClientByAuthToken", "test-auth-token").Return(testDevice, nil) + userDB.On("UpdateClientSequence", "test-device-id", uint64(1), + mock.AnythingOfType("[]uint8"), mock.AnythingOfType("[]string")).Return(nil) + + db := &database.Database{UserDB: userDB} + + // Create encrypted payload + payload := DecryptedPayload{ + JSONRPC: "2.0", + Method: "test.method", + ID: 1, + Seq: 1, + Nonce: "test-nonce-123", + } + + payloadJSON, err := json.Marshal(payload) + require.NoError(t, err) + + encryptedData, iv, err := EncryptPayload(payloadJSON, testSecret) + require.NoError(t, err) + + encRequest := EncryptedRequest{ + Encrypted: encryptedData, + IV: iv, + AuthToken: "test-auth-token", + } + + encRequestJSON, err := json.Marshal(encRequest) + require.NoError(t, err) + + // Test handler that checks the decrypted request + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request body was decrypted properly + var jsonRPC map[string]any + err := json.NewDecoder(r.Body).Decode(&jsonRPC) + assert.NoError(t, err) + + assert.Equal(t, "2.0", jsonRPC["jsonrpc"]) + assert.Equal(t, "test.method", jsonRPC["method"]) + assert.InDelta(t, float64(1), jsonRPC["id"], 0.001) // JSON unmarshals numbers as float64 + + // Verify device is in context + device := GetClientFromContext(r.Context()) + assert.NotNil(t, device) + assert.Equal(t, "test-device-id", device.ClientID) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("authenticated success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + // Test successful authentication + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(encRequestJSON)) + req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "authenticated success", w.Body.String()) + userDB.AssertExpectations(t) +} + +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation +func TestReplayProtector_ReplayProtection(t *testing.T) { + t.Parallel() + tests := []struct { + name string + newNonce string + description string + seqWindow []byte + nonceCache []string + currentSeq uint64 + newSeq uint64 + expectedResult bool + }{ + { + name: "first message", + currentSeq: 0, + seqWindow: make([]byte, 8+128*8), // Ring buffer: 8 bytes + 128 blocks * 8 bytes + nonceCache: []string{}, + newSeq: 1, + newNonce: "nonce1", + expectedResult: true, + description: "first message should always pass", + }, + { + name: "sequence increment", + currentSeq: 5, + seqWindow: make([]byte, 8+128*8), + nonceCache: []string{"old-nonce"}, + newSeq: 6, + newNonce: "nonce6", + expectedResult: true, + description: "incrementing sequence should pass", + }, + { + name: "duplicate nonce", + currentSeq: 5, + seqWindow: make([]byte, 8+128*8), + nonceCache: []string{"duplicate-nonce"}, + newSeq: 6, + newNonce: "duplicate-nonce", + expectedResult: false, + description: "duplicate nonce should be rejected", + }, + { + name: "old sequence far out of window", + currentSeq: 50000, + seqWindow: make([]byte, 8+128*8), + nonceCache: []string{}, + newSeq: 100, // More than 8000+ behind (outside WireGuard window) + newNonce: "nonce100", + expectedResult: false, + description: "sequence too far behind should be rejected", + }, + { + name: "sequence within large window", + currentSeq: 1000, + seqWindow: make([]byte, 8+128*8), + nonceCache: []string{}, + newSeq: 950, // Within large window + newNonce: "nonce950", + expectedResult: true, + description: "sequence within sliding window should pass", + }, + { + name: "large sequence jump forward", + currentSeq: 5, + seqWindow: make([]byte, 8+128*8), + nonceCache: []string{}, + newSeq: 1000, // Large jump forward + newNonce: "nonce1000", + expectedResult: true, + description: "large sequence jump should be accepted", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &database.Client{ + ClientID: "test-device", + CurrentSeq: tt.currentSeq, + SeqWindow: tt.seqWindow, + NonceCache: tt.nonceCache, + } + + replayProtector := NewReplayProtector(client) + result := replayProtector.ValidateSequenceAndNonce(tt.newSeq, tt.newNonce) + assert.Equal(t, tt.expectedResult, result, tt.description) + }) + } +} + +func TestEncryptDecryptPayload(t *testing.T) { + t.Parallel() + testKey := []byte("test-encryption-key-32bytes-ok!!") + originalData := []byte(`{"jsonrpc":"2.0","method":"test","id":1}`) + + // Test encryption + encrypted, iv, err := EncryptPayload(originalData, testKey) + require.NoError(t, err) + assert.NotEmpty(t, encrypted) + assert.NotEmpty(t, iv) + + // Test decryption + decrypted, err := DecryptPayload(encrypted, iv, testKey) + require.NoError(t, err) + assert.Equal(t, originalData, decrypted) +} + +func TestEncryptDecryptPayload_WrongKey(t *testing.T) { + t.Parallel() + correctKey := []byte("correct-key-32-bytes-long-ok!!!!") + wrongKey := []byte("wrong-key-32-bytes-long-ok!!!!!!") + originalData := []byte(`{"test": "data"}`) + + // Encrypt with correct key + encrypted, iv, err := EncryptPayload(originalData, correctKey) + require.NoError(t, err) + + // Try to decrypt with wrong key - should fail + _, err = DecryptPayload(encrypted, iv, wrongKey) + require.Error(t, err, "decryption should fail with wrong key") + assert.Contains(t, err.Error(), "decryption failed") +} + +//nolint:paralleltest,tparallel // Security tests require deterministic execution order +func TestIsLocalhost(t *testing.T) { + t.Parallel() + tests := []struct { + addr string + expected bool + }{ + {"127.0.0.1:8080", true}, + {"127.0.0.1", true}, + {"localhost:3000", true}, + {"localhost", true}, + {"[::1]:8080", true}, + {"::1", true}, + {"192.168.1.1:8080", false}, + {"8.8.8.8:53", false}, + {"example.com:80", false}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + result := isLocalhost(tt.addr) + assert.Equal(t, tt.expected, result) + }) + } +} + +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation +func TestAuthMiddleware_InvalidRequests(t *testing.T) { + t.Parallel() + userDB := helpers.NewMockUserDBI() + // Mock empty auth token lookup - expect it to be called once for the missing auth token test + userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Once() + + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called for invalid requests") + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + body string + description string + expectedCode int + }{ + { + name: "invalid json", + body: `{"invalid json`, + expectedCode: http.StatusBadRequest, + description: "malformed JSON should be rejected", + }, + { + name: "missing auth token", + body: `{"encrypted": "data", "iv": "iv"}`, + expectedCode: http.StatusUnauthorized, + description: "missing auth token should be rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(tt.body))) + req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedCode, w.Code, tt.description) + }) + } + + userDB.AssertExpectations(t) +} + +func TestGetClientFromContext(t *testing.T) { + t.Parallel() + // Test with client in context + client := &database.Client{ClientID: "test-device"} + ctx := context.WithValue(context.Background(), clientKey("client"), client) + + result := GetClientFromContext(ctx) + assert.Equal(t, client, result) + + // Test with no client in context + emptyCtx := context.Background() + result = GetClientFromContext(emptyCtx) + assert.Nil(t, result) + + // Test with wrong type in context + badCtx := context.WithValue(context.Background(), clientKey("client"), "not-a-client") + result = GetClientFromContext(badCtx) + assert.Nil(t, result) +} + +// TestAuthMiddleware_ConcurrentRequests verifies that the race condition fix +// prevents concurrent requests from bypassing replay protection +func TestAuthMiddleware_ConcurrentRequests(t *testing.T) { + t.Parallel() + + // This test verifies the mutex locking works correctly + // We'll just test that concurrent access to the mutex manager is safe + + const numConcurrentRequests = 20 + const deviceID = "test-device-concurrent" + + done := make(chan struct{}, numConcurrentRequests) + var lockAcquired int32 + + for range numConcurrentRequests { + go func() { + defer func() { done <- struct{}{} }() + + // Acquire device lock - this should be thread-safe + unlockDevice := LockClient(deviceID) + + // Critical section - only one goroutine should be here at a time + current := atomic.AddInt32(&lockAcquired, 1) + if current != 1 { + t.Errorf("Race condition detected: %d goroutines in critical section", current) + } + + // Simulate some work + time.Sleep(1 * time.Millisecond) + + atomic.AddInt32(&lockAcquired, -1) + unlockDevice() + }() + } + + // Wait for all requests to complete + for range numConcurrentRequests { + <-done + } + + // Verify no race conditions occurred + assert.Equal(t, int32(0), atomic.LoadInt32(&lockAcquired), "All locks should be released") +} + +// TestAuthMiddleware_ConcurrentAuthentication tests actual concurrent authentication +// requests to ensure the race condition fix prevents replay attacks +func TestAuthMiddleware_ConcurrentAuthentication(t *testing.T) { + t.Parallel() + + // Setup + userDB := helpers.NewMockUserDBI() + db := &database.Database{UserDB: userDB} + + // Create a test client + testClient := &database.Client{ + ClientID: "test-client-concurrent", + ClientName: "Test Client", + AuthTokenHash: "test-hash", + SharedSecret: []byte("test-secret-32-bytes-long-key!!!"), + CurrentSeq: 5, + SeqWindow: make([]byte, 8), + NonceCache: []string{}, + CreatedAt: time.Now(), + LastSeen: time.Now(), + } + + // Mock database calls + userDB.On("GetClientByAuthToken", "test-token").Return(testClient, nil) + userDB.On("UpdateClientSequence", + testClient.ClientID, + mock.AnythingOfType("uint64"), + mock.AnythingOfType("[]uint8"), + mock.AnythingOfType("[]string")).Return(nil) + + // Create test handler that tracks successful authentications + var successCount int32 + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&successCount, 1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + // Create encrypted request with same sequence number (should cause replay detection) + payload := map[string]any{ + "jsonrpc": "2.0", + "method": "test.method", + "id": 1, + "nonce": "test-nonce-concurrent", + "seq": uint64(6), // Same sequence for all requests + } + payloadJSON, _ := json.Marshal(payload) + + encrypted, iv, err := EncryptPayload(payloadJSON, testClient.SharedSecret) + require.NoError(t, err) + + encRequest := map[string]string{ + "encrypted": encrypted, + "iv": iv, + "authToken": "test-token", + } + requestBody, _ := json.Marshal(encRequest) + + // Run concurrent requests + const numRequests = 10 + done := make(chan int, numRequests) + + for i := range numRequests { + go func(_ int) { + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(requestBody)) + req.RemoteAddr = "192.168.1.100:12345" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + done <- w.Code + }(i) + } + + // Collect results + statusCodes := make([]int, 0, numRequests) + for range numRequests { + statusCodes = append(statusCodes, <-done) + } + + // Only ONE request should succeed (200), others should fail with replay detection (400) + successStatusCount := 0 + for _, code := range statusCodes { + if code == http.StatusOK { + successStatusCount++ + } else { + assert.Equal(t, http.StatusBadRequest, code, "Failed requests should return 400 for replay detection") + } + } + + assert.Equal(t, 1, successStatusCount, "Only one concurrent request should succeed") + assert.Equal(t, int32(1), atomic.LoadInt32(&successCount), "Handler should only be called once") + + userDB.AssertExpectations(t) +} + +// TestClientMutexManager_Cleanup verifies mutex cleanup works correctly +func TestClientMutexManager_Cleanup(t *testing.T) { + t.Parallel() + + dm := &ClientMutexManager{} + + // Create some mutexes + mutex1 := dm.getClientMutex("device1") + mutex2 := dm.getClientMutex("device2") + mutex3 := dm.getClientMutex("device3") + + require.NotNil(t, mutex1) + require.NotNil(t, mutex2) + require.NotNil(t, mutex3) + + // Age some mutexes + mutex1.lastUsed = time.Now().Add(-31 * time.Minute) // Should be cleaned + mutex2.lastUsed = time.Now().Add(-20 * time.Minute) // Should remain + mutex3.lastUsed = time.Now() // Should remain + + // Run cleanup + dm.cleanup() + + // Check that only old mutex was removed + _, exists1 := dm.mutexes.Load("device1") + _, exists2 := dm.mutexes.Load("device2") + _, exists3 := dm.mutexes.Load("device3") + + assert.False(t, exists1, "Old mutex should be cleaned up") + assert.True(t, exists2, "Recent mutex should remain") + assert.True(t, exists3, "Current mutex should remain") +} + +// TestClientMutexManager_ConcurrentAccess verifies thread safety of mutex manager +func TestClientMutexManager_ConcurrentAccess(t *testing.T) { + t.Parallel() + + dm := &ClientMutexManager{} + const numGoroutines = 50 + const deviceID = "concurrent-test-device" + + done := make(chan struct{}, numGoroutines) + + // Launch multiple goroutines that get/create mutex for same device + for range numGoroutines { + go func() { + defer func() { done <- struct{}{} }() + + mutex := dm.getClientMutex(deviceID) + assert.NotNil(t, mutex) + if mutex != nil { + assert.Equal(t, deviceID, mutex.clientID) + } + }() + } + + // Wait for all goroutines to complete + for range numGoroutines { + <-done + } + + // Verify only one mutex was created for the device + value, exists := dm.mutexes.Load(deviceID) + assert.True(t, exists, "Mutex should exist") + + mutex, ok := value.(*clientMutex) + require.True(t, ok) + assert.Equal(t, deviceID, mutex.clientID) +} diff --git a/pkg/api/middleware/ratelimit_test.go b/pkg/api/middleware/ratelimit_test.go new file mode 100644 index 000000000..891d7002b --- /dev/null +++ b/pkg/api/middleware/ratelimit_test.go @@ -0,0 +1,206 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestIPRateLimiter_BasicFunctionality(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiter for IP + rl := limiter.GetLimiter("192.168.1.100") + assert.NotNil(t, rl) + + // Should allow initial requests up to burst size + for i := range BurstSize { + allowed := rl.Allow() + assert.True(t, allowed, "should allow request %d within burst", i+1) + } + + // Should block additional requests beyond burst + blocked := rl.Allow() + assert.False(t, blocked, "should block request beyond burst size") +} + +func TestIPRateLimiter_DifferentIPs(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiters for different IPs + rl1 := limiter.GetLimiter("192.168.1.100") + rl2 := limiter.GetLimiter("192.168.1.101") + + // Should be different limiters + assert.NotSame(t, rl1, rl2) + + // Exhaust first limiter + for range BurstSize { + rl1.Allow() + } + + // First limiter should be blocked + assert.False(t, rl1.Allow()) + + // Second limiter should still allow requests + assert.True(t, rl2.Allow()) +} + +func TestIPRateLimiter_SameIPReuse(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiter for same IP twice + rl1 := limiter.GetLimiter("192.168.1.100") + rl2 := limiter.GetLimiter("192.168.1.100") + + // Should be same limiter instance + assert.Same(t, rl1, rl2) +} + +func TestHTTPRateLimitMiddleware_Allow(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + // Should allow initial requests + for i := range BurstSize { + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = "192.168.1.100:12345" + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "should allow request %d", i+1) + assert.Equal(t, "success", w.Body.String()) + } +} + +func TestHTTPRateLimitMiddleware_Block(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called when rate limited") + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + // Exhaust rate limit + ip := "192.168.1.100:12345" + ipLimiter := limiter.GetLimiter("192.168.1.100") + for range BurstSize { + ipLimiter.Allow() + } + + // Next request should be blocked + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = ip + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTooManyRequests, w.Code) + assert.Contains(t, w.Body.String(), "Too Many Requests") +} + +func TestIPRateLimiter_Cleanup(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Add a limiter with old timestamp manually + limiter.limiters["old.ip"] = &rateLimiterEntry{ + limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize), + lastSeen: time.Now().Add(-15 * time.Minute), // Old timestamp + } + + // Add a limiter with recent timestamp manually + limiter.limiters["new.ip"] = &rateLimiterEntry{ + limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize), + lastSeen: time.Now(), // Recent timestamp + } + + // Verify both exist + assert.Len(t, limiter.limiters, 2) + + // Run cleanup + limiter.Cleanup() + + // Old entry should be removed, new one should remain + assert.Len(t, limiter.limiters, 1) + assert.Contains(t, limiter.limiters, "new.ip") + assert.NotContains(t, limiter.limiters, "old.ip") +} + +func TestHTTPRateLimitMiddleware_IPExtraction(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + expectedIP string + }{ + {"with port", "192.168.1.100:12345", "192.168.1.100"}, + {"without port", "192.168.1.100", "192.168.1.100"}, + {"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Should succeed (IP extraction worked) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify the correct IP was used for rate limiting + ipLimiter := limiter.GetLimiter(tt.expectedIP) + assert.NotNil(t, ipLimiter) + }) + } +} diff --git a/pkg/api/middleware/replay_filter.go b/pkg/api/middleware/replay_filter.go new file mode 100644 index 000000000..7bf923dec --- /dev/null +++ b/pkg/api/middleware/replay_filter.go @@ -0,0 +1,93 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +/* +ATTRIBUTION NOTICE: + +This file contains code derived from WireGuard's anti-replay filter implementation. + +Original work: Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. +Original license: MIT License +Source: https://github.com/WireGuard/wireguard-go/blob/master/replay/replay.go + +The derived portions remain under their original MIT license as required. +The combined work is distributed under GPL-3.0-or-later as indicated in the file header. +*/ + +type block uint64 + +const ( + blockBitLog = 6 + blockBits = 1 << blockBitLog + ringBlocks = 1 << 7 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 +) + +// Filter implements an anti-replay mechanism as specified in RFC 6479. +// It rejects replayed messages by checking if message counter value is within +// a sliding window of previously received messages. +type Filter struct { + last uint64 + ring [ringBlocks]block +} + +// Reset resets the filter to an empty state. +func (f *Filter) Reset() { + f.last = 0 + for i := range f.ring { + f.ring[i] = 0 + } +} + +// ValidateCounter checks if a given counter should be accepted. +// It automatically rejects counters >= the specified limit. +func (f *Filter) ValidateCounter(counter, limit uint64) bool { + if counter >= limit { + return false + } + + indexBlock := counter >> blockBitLog + + if counter > f.last { + // New highest counter - shift window + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks + } + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { + // Too old + return false + } + + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + updated := old | 1<. + +package middleware + +import ( + "testing" +) + +const RejectAfterMessages = 1<<64 - 1<<13 - 1 + +func TestReplayFilter(t *testing.T) { + t.Parallel() + var filter Filter + const tLim = windowSize + 1 + testNumber := 0 + + testFunc := func(n uint64, expected bool) { + testNumber++ + if filter.ValidateCounter(n, RejectAfterMessages) != expected { + t.Fatal("Test", testNumber, "failed", n, expected) + } + } + + filter.Reset() + + testFunc(0, true) + testFunc(1, true) + testFunc(1, false) + testFunc(9, true) + testFunc(8, true) + testFunc(7, true) + testFunc(7, false) + testFunc(tLim, true) + testFunc(tLim-1, true) + testFunc(tLim-1, false) + testFunc(tLim-2, true) + testFunc(2, true) + testFunc(2, false) + testFunc(tLim+16, true) + testFunc(3, false) + testFunc(tLim+16, false) + testFunc(tLim*4, true) + testFunc(tLim*4-(tLim-1), true) + testFunc(10, false) + testFunc(tLim*4-tLim, false) + testFunc(tLim*4-(tLim+1), false) + testFunc(tLim*4-(tLim-2), true) + testFunc(tLim*4+1-tLim, false) + testFunc(0, false) + testFunc(RejectAfterMessages, false) + testFunc(RejectAfterMessages-1, true) + testFunc(RejectAfterMessages, false) + testFunc(RejectAfterMessages-1, false) + testFunc(RejectAfterMessages-2, true) + testFunc(RejectAfterMessages+1, false) + testFunc(RejectAfterMessages+2, false) + testFunc(RejectAfterMessages-2, false) + testFunc(RejectAfterMessages-3, true) + testFunc(0, false) + + t.Log("Bulk test 1") + filter.Reset() + testNumber = 0 + for i := uint64(1); i <= windowSize; i++ { + testFunc(i, true) + } + testFunc(0, true) + testFunc(0, false) + + t.Log("Bulk test 2") + filter.Reset() + testNumber = 0 + for i := uint64(2); i <= windowSize+1; i++ { + testFunc(i, true) + } + testFunc(1, true) + testFunc(0, false) + + t.Log("Bulk test 3") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize + 1); i > 0; i-- { + testFunc(i, true) + } + + t.Log("Bulk test 4") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize + 2); i > 1; i-- { + testFunc(i, true) + } + testFunc(0, false) + + t.Log("Bulk test 5") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize); i > 0; i-- { + testFunc(i, true) + } + testFunc(windowSize+1, true) + testFunc(0, false) + + t.Log("Bulk test 6") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize); i > 0; i-- { + testFunc(i, true) + } + testFunc(0, true) + testFunc(windowSize+1, true) +} diff --git a/pkg/api/middleware/replay_protection.go b/pkg/api/middleware/replay_protection.go new file mode 100644 index 000000000..d28146472 --- /dev/null +++ b/pkg/api/middleware/replay_protection.go @@ -0,0 +1,113 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "encoding/binary" + "slices" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" +) + +const ( + MaxSequenceNumber = 1<<64 - 1<<13 - 1 // Same as WireGuard's RejectAfterMessages + NonceCacheSize = 100 // Maximum number of cached nonces +) + +// ReplayProtector combines WireGuard's sequence number validation with nonce replay protection +type ReplayProtector struct { + nonceCache []string + filter Filter + lastSeq uint64 +} + +// NewReplayProtector creates a new replay protector from client state +func NewReplayProtector(client *database.Client) *ReplayProtector { + rp := &ReplayProtector{ + nonceCache: make([]string, len(client.NonceCache)), + lastSeq: client.CurrentSeq, + } + copy(rp.nonceCache, client.NonceCache) + + // Initialize filter from stored ring buffer state + if len(client.SeqWindow) >= 8+len(rp.filter.ring)*8 { + storedLastSeq := binary.LittleEndian.Uint64(client.SeqWindow[0:8]) + if storedLastSeq > 0 { + // Load stored state: first 8 bytes = last sequence, then 8 bytes per ring block + rp.filter.last = storedLastSeq + for i := range rp.filter.ring { + offset := 8 + i*8 + rp.filter.ring[i] = block(binary.LittleEndian.Uint64(client.SeqWindow[offset : offset+8])) + } + } else { + // Buffer exists but uninitialized - treat as new client + rp.filter.Reset() + if client.CurrentSeq > 0 { + rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber) + } + } + } else { + // Buffer too small - initialize fresh + rp.filter.Reset() + if client.CurrentSeq > 0 { + rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber) + } + } + + return rp +} + +// ValidateSequenceAndNonce validates both sequence number and nonce for replay protection +func (rp *ReplayProtector) ValidateSequenceAndNonce(seq uint64, nonce string) bool { + // Check nonce first (replay protection) + if slices.Contains(rp.nonceCache, nonce) { + return false + } + + // Then validate sequence number using WireGuard's algorithm + return rp.filter.ValidateCounter(seq, MaxSequenceNumber) +} + +// UpdateSequenceAndNonce updates the replay protector state after successful validation +func (rp *ReplayProtector) UpdateSequenceAndNonce(seq uint64, nonce string) { + // Update nonce cache (keep last NonceCacheSize nonces) + rp.nonceCache = append(rp.nonceCache, nonce) + if len(rp.nonceCache) > NonceCacheSize { + rp.nonceCache = rp.nonceCache[1:] // Remove oldest + } + + // Sequence is already updated by ValidateCounter if it was accepted + rp.lastSeq = max(rp.lastSeq, seq) +} + +// GetStateForDatabase returns the state that should be stored in the database +func (rp *ReplayProtector) GetStateForDatabase() (currentSeq uint64, seqWindow []byte, nonceCache []string) { + // Serialize ring buffer state + // Format: [8 bytes last seq][8 bytes per ring block] + seqWindow = make([]byte, 8+len(rp.filter.ring)*8) + binary.LittleEndian.PutUint64(seqWindow[0:8], rp.filter.last) + + for i, block := range rp.filter.ring[:] { + offset := 8 + i*8 + binary.LittleEndian.PutUint64(seqWindow[offset:offset+8], uint64(block)) + } + + return rp.filter.last, seqWindow, rp.nonceCache +} diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go new file mode 100644 index 000000000..335c604aa --- /dev/null +++ b/pkg/api/pairing.go @@ -0,0 +1,355 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package api + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "regexp" + "runtime" + "strings" + "sync" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/hkdf" +) + +const ( + PairingTokenExpiry = 5 * time.Minute + PairingAttemptLimit = 10 +) + +type PairingSession struct { + CreatedAt time.Time + Token string + Challenge []byte + Salt []byte + Attempts int +} + +type PairingManager struct { + sessions map[string]*PairingSession + mu sync.RWMutex +} + +type PairingInitiateRequest struct { + ClientName string `json:"clientName"` +} + +type PairingInitiateResponse struct { + PairingToken string `json:"pairingToken"` + ExpiresIn int `json:"expiresIn"` +} + +type PairingCompleteRequest struct { + PairingToken string `json:"pairingToken"` + Verifier string `json:"verifier"` + ClientName string `json:"clientName"` +} + +type PairingCompleteResponse struct { + ClientID string `json:"clientId"` + AuthToken string `json:"authToken"` + SharedSecret string `json:"sharedSecret"` // Base64 encoded +} + +var pairingManager = &PairingManager{ + sessions: make(map[string]*PairingSession), +} + +var clientNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +func validateClientName(name string) error { + name = strings.TrimSpace(name) + + if name == "" { + return errors.New("client name cannot be empty") + } + + if len(name) > 100 { + return errors.New("client name too long (max 100 characters)") + } + + if !clientNameRegex.MatchString(name) { + return errors.New("client name contains invalid characters (only letters, numbers, underscore, and dash)") + } + + return nil +} + +// StartPairingCleanup starts the pairing session cleanup routine +func StartPairingCleanup(ctx context.Context) { + go pairingManager.cleanup(ctx) +} + +func (pm *PairingManager) createSession() (*PairingSession, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + // Generate random token and challenge + token := uuid.New().String() + challenge := make([]byte, 32) + if _, err := rand.Read(challenge); err != nil { + return nil, fmt.Errorf("failed to generate challenge: %w", err) + } + + // Generate random salt (32 bytes for SHA-256) + salt := make([]byte, 32) + if _, err := rand.Read(salt); err != nil { + return nil, fmt.Errorf("failed to generate salt: %w", err) + } + + session := &PairingSession{ + Token: token, + Challenge: challenge, + Salt: salt, + CreatedAt: time.Now(), + Attempts: 0, + } + + pm.sessions[token] = session + + log.Debug().Str("token", token[:8]+"...").Msg("created pairing session") + return session, nil +} + +func (pm *PairingManager) consumeSession(token string) (*PairingSession, bool) { + pm.mu.Lock() + defer pm.mu.Unlock() + + session, exists := pm.sessions[token] + if !exists { + return nil, false + } + + // Check if expired + if time.Since(session.CreatedAt) > PairingTokenExpiry { + delete(pm.sessions, token) + return nil, false + } + + // Increment attempts + session.Attempts++ + + // Check attempt limit + if session.Attempts >= PairingAttemptLimit { + delete(pm.sessions, token) + return nil, false + } + + return session, true +} + +func (pm *PairingManager) cleanup(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Debug().Msg("pairing manager cleanup routine stopped") + return + case <-ticker.C: + pm.mu.Lock() + now := time.Now() + for token, session := range pm.sessions { + if now.Sub(session.CreatedAt) > PairingTokenExpiry { + // Zero sensitive data before deletion + for i := range session.Challenge { + session.Challenge[i] = 0 + } + for i := range session.Salt { + session.Salt[i] = 0 + } + delete(pm.sessions, token) + } + } + pm.mu.Unlock() + } + } +} + +func handlePairingInitiate(_ *database.Database) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + session, err := pairingManager.createSession() + if err != nil { + log.Error().Err(err).Msg("failed to create pairing session") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + response := PairingInitiateResponse{ + PairingToken: session.Token, + ExpiresIn: int(PairingTokenExpiry.Seconds()), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("failed to encode pairing response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Info().Str("token", session.Token[:8]+"...").Msg("pairing session initiated") + } +} + +func handlePairingComplete(db *database.Database) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req PairingCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + + if req.PairingToken == "" || req.Verifier == "" || req.ClientName == "" { + http.Error(w, "Missing required fields", http.StatusBadRequest) + return + } + + // Validate client name + if err := validateClientName(req.ClientName); err != nil { + log.Warn(). + Str("client_name", req.ClientName). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("SECURITY: invalid client name provided") + http.Error(w, fmt.Sprintf("Invalid client name: %s", err.Error()), http.StatusBadRequest) + return + } + + // Get and consume pairing session + session, exists := pairingManager.consumeSession(req.PairingToken) + if !exists { + http.Error(w, "Invalid or expired pairing token", http.StatusBadRequest) + return + } + + // First, use Argon2id for strong key derivation from verifier + // Using security parameters recommended by OWASP (2023) + verifierKey := argon2.IDKey( + []byte(req.Verifier), // password + session.Salt, // salt (32 bytes) + 3, // time parameter (iterations) + 64*1024, // memory parameter (64 MB) + 4, // parallelism parameter + 32, // key length (32 bytes) + ) + // Ensure verifierKey is zeroed after use + defer func() { + for i := range verifierKey { + verifierKey[i] = 0 + } + runtime.KeepAlive(verifierKey) // Prevent compiler optimization + }() + + // Then combine with challenge using HKDF for domain separation + combinedSecret := make([]byte, len(session.Challenge)+len(verifierKey)) + copy(combinedSecret, session.Challenge) + copy(combinedSecret[len(session.Challenge):], verifierKey) + // Ensure combinedSecret is zeroed after use + defer func() { + for i := range combinedSecret { + combinedSecret[i] = 0 + } + runtime.KeepAlive(combinedSecret) // Prevent compiler optimization + }() + + sharedSecret := make([]byte, 32) // 256 bits for AES-256 + + // Construct context-specific info string for domain separation + info := []byte("zaparoo-pairing-v1|" + req.PairingToken + "|" + req.ClientName) + + kdf := hkdf.New(sha256.New, combinedSecret, session.Salt, info) + if _, err := kdf.Read(sharedSecret); err != nil { + log.Error().Err(err).Msg("failed to derive shared secret") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Generate auth token + authToken := uuid.New().String() + + // Create client in database + client, err := db.UserDB.CreateClient(req.ClientName, authToken, sharedSecret) + if err != nil { + log.Error().Err(err).Msg("failed to create client") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Zero sensitive session data before removing + for i := range session.Challenge { + session.Challenge[i] = 0 + } + for i := range session.Salt { + session.Salt[i] = 0 + } + + // Remove session from manager + pairingManager.mu.Lock() + delete(pairingManager.sessions, req.PairingToken) + pairingManager.mu.Unlock() + + response := PairingCompleteResponse{ + ClientID: client.ClientID, + AuthToken: authToken, + SharedSecret: hex.EncodeToString(sharedSecret), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("failed to encode pairing response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Info(). + Str("client_id", client.ClientID). + Str("client_name", client.ClientName). + Msg("client paired successfully") + } +} + +// QRCodeData represents the data embedded in the pairing QR code +type QRCodeData struct { + Address string `json:"address"` + Token string `json:"token"` +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 4e0dc2f4d..86007649d 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -22,6 +22,7 @@ package api import ( "bytes" "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -51,8 +52,10 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/google/uuid" + "github.com/gorilla/csrf" "github.com/olahol/melody" "github.com/rs/zerolog/log" + "github.com/unrolled/secure" ) var allowedOrigins = []string{ @@ -427,7 +430,7 @@ func buildDynamicAllowedOrigins(baseOrigins, localIPs []string, port int, custom } // broadcastNotifications consumes and broadcasts all incoming API -// notifications to all connected clients. +// notifications to all connected clients with appropriate encryption. func broadcastNotifications( st *state.State, session *melody.Melody, @@ -451,11 +454,61 @@ func broadcastNotifications( continue } - // TODO: this will not work with encryption - err = session.Broadcast(data) - if err != nil { - log.Error().Err(err).Msg("broadcasting notification") - } + // Broadcast to localhost sessions (unencrypted) + _ = session.BroadcastFilter(data, func(s *melody.Session) bool { + rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2) + clientIP := net.ParseIP(rawIP[0]) + return clientIP.IsLoopback() + }) + + // Broadcast to authenticated remote sessions (encrypted) + _ = session.BroadcastFilter(nil, func(s *melody.Session) bool { + // Check if this is a remote connection + rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2) + clientIP := net.ParseIP(rawIP[0]) + if clientIP.IsLoopback() { + return false // Skip localhost + } + + // Check if session is authenticated + client, authenticated := s.Get("client") + if !authenticated { + return false // Skip unauthenticated + } + + // Encrypt notification for this session + clientObj, ok := client.(*database.Client) + if !ok { + log.Error().Msg("invalid client type in session") + return false + } + encrypted, iv, err := apimiddleware.EncryptPayload(data, clientObj.SharedSecret) + if err != nil { + log.Error().Err(err).Str("client_id", clientObj.ClientID).Msg("failed to encrypt notification") + return false + } + + encResponse := apimiddleware.EncryptedRequest{ + Encrypted: encrypted, + IV: iv, + AuthToken: "", // Not needed for notifications + } + + encData, err := json.Marshal(encResponse) + if err != nil { + log.Error().Err(err).Str("client_id", clientObj.ClientID). + Msg("failed to marshal encrypted notification") + return false + } + + // Send encrypted data to this session + if err := s.Write(encData); err != nil { + log.Error().Err(err).Str("client_id", clientObj.ClientID). + Msg("failed to send encrypted notification") + } + + return false // Don't include in broadcast since we already sent manually + }) } } } @@ -517,6 +570,77 @@ func processRequestObject( // handleWSMessage parses all incoming WS requests, identifies what type of // JSON-RPC object they may be and forwards them to the appropriate function // to handle that type of message. +// handleWSPing handles ping messages for heartbeat operation +func handleWSPing(session *melody.Session, msg []byte) bool { + if bytes.Equal(msg, []byte("ping")) { + err := session.Write([]byte("pong")) + if err != nil { + log.Error().Err(err).Msg("sending pong") + } + return true + } + return false +} + +// handleWSAuthentication handles authentication for remote WebSocket connections +func handleWSRemoteAuth(session *melody.Session, msg []byte, db *database.Database) ([]byte, error) { + client, authenticated := session.Get("client") + if !authenticated { + // First message must be authentication + err := handleWSAuthentication(session, msg, db) + if err != nil { + log.Warn().Err(err). + Str("remote_addr", session.Request.RemoteAddr). + Str("user_agent", session.Request.Header.Get("User-Agent")). + Msg("SECURITY: WebSocket authentication failed") + _ = session.Close() + } + return nil, err + } + + // Check session timeout + if isSessionExpired(session) { + log.Warn(). + Str("remote_addr", session.Request.RemoteAddr). + Msg("SECURITY: WebSocket session expired - closing connection") + _ = session.Close() + return nil, errors.New("session expired") + } + + // Update last activity time + session.Set("auth_time", time.Now()) + + // Decrypt message for authenticated remote connection + clientObj, ok := client.(*database.Client) + if !ok { + log.Error().Msg("invalid client type in session") + return nil, errors.New("invalid client type") + } + + decryptedMsg, err := handleWSDecryption(session, msg, clientObj, db) + if err != nil { + log.Error().Err(err).Msg("WebSocket decryption failed") + return nil, err + } + + return decryptedMsg, nil +} + +// sendWSResponseByType sends response based on connection type (local/remote) +func sendWSResponseByType(session *melody.Session, isLocal bool, id uuid.UUID, resp any) error { + if !isLocal { + if client, authenticated := session.Get("client"); authenticated { + clientObj, ok := client.(*database.Client) + if !ok { + log.Error().Msg("invalid client type in session") + return errors.New("invalid client type") + } + return sendWSResponseEncrypted(session, id, resp, clientObj) + } + } + return sendWSResponse(session, id, resp) +} + func handleWSMessage( methodMap *MethodMap, platform platforms.Platform, @@ -536,24 +660,40 @@ func handleWSMessage( } }() - // ping command for heartbeat operation - if bytes.Equal(msg, []byte("ping")) { - err := session.Write([]byte("pong")) - if err != nil { - log.Error().Err(err).Msg("sending pong") - } + // Handle ping/pong heartbeat + if handleWSPing(session, msg) { return } + // Determine if connection is local rawIP := strings.SplitN(session.Request.RemoteAddr, ":", 2) clientIP := net.ParseIP(rawIP[0]) + isLocal := clientIP.IsLoopback() + + // Handle authentication and decryption for remote connections + if !isLocal { + decryptedMsg, err := handleWSRemoteAuth(session, msg, db) + if err != nil { + err := sendWSError(session, uuid.Nil, JSONRPCErrorInvalidRequest) + if err != nil { + log.Error().Err(err).Msg("error sending auth error response") + } + return + } + if decryptedMsg == nil { + return // Authentication in progress + } + msg = decryptedMsg + } + + // Process the request env := requests.RequestEnv{ Platform: platform, Config: cfg, State: st, Database: db, TokenQueue: inTokenQueue, - IsLocal: clientIP.IsLoopback(), + IsLocal: isLocal, } id, resp, rpcError := processRequestObject(methodMap, env, msg) @@ -563,7 +703,7 @@ func handleWSMessage( log.Error().Err(err).Msg("error sending error response") } } else { - err := sendWSResponse(session, id, resp) + err := sendWSResponseByType(session, isLocal, id, resp) if err != nil { log.Error().Err(err).Msg("error sending response") } @@ -571,6 +711,177 @@ func handleWSMessage( } } +type WSAuthMessage struct { + AuthToken string `json:"authToken"` +} + +type CSRFTokenResponse struct { + CSRFToken string `json:"csrfToken"` +} + +const sessionTimeout = 30 * time.Minute + +func isSessionExpired(session *melody.Session) bool { + authTime, exists := session.Get("auth_time") + if !exists { + return true // No auth time means expired + } + + timestamp, ok := authTime.(time.Time) + if !ok { + return true // Invalid timestamp means expired + } + + return time.Since(timestamp) > sessionTimeout +} + +func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Database) error { + var authMsg WSAuthMessage + if err := json.Unmarshal(msg, &authMsg); err != nil { + return fmt.Errorf("invalid auth message format: %w", err) + } + + if authMsg.AuthToken == "" { + return errors.New("missing auth token") + } + + // Validate auth token and get client + client, err := db.UserDB.GetClientByAuthToken(authMsg.AuthToken) + if err != nil { + return fmt.Errorf("invalid auth token: %w", err) + } + + // Store client and auth timestamp in session + session.Set("client", client) + session.Set("auth_time", time.Now()) + + // Send authentication success response + authResponse := map[string]any{ + "authenticated": true, + "client_id": client.ClientID, + } + + responseData, _ := json.Marshal(authResponse) + err = session.Write(responseData) + if err != nil { + return fmt.Errorf("failed to send auth response: %w", err) + } + + log.Info(). + Str("client_id", client.ClientID). + Str("remote_addr", session.Request.RemoteAddr). + Msg("SECURITY: WebSocket authenticated successfully") + return nil +} + +func handleWSDecryption(_ *melody.Session, msg []byte, client *database.Client, db *database.Database) ([]byte, error) { + // Parse encrypted message + var encMsg apimiddleware.EncryptedRequest + if err := json.Unmarshal(msg, &encMsg); err != nil { + return nil, fmt.Errorf("invalid encrypted message format: %w", err) + } + + // Decrypt payload (we can reuse the middleware function by creating a temporary import) + decryptedPayload, err := apimiddleware.DecryptPayload(encMsg.Encrypted, encMsg.IV, client.SharedSecret) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + // Parse and validate sequence/nonce + var payload apimiddleware.DecryptedPayload + if unmarshalErr := json.Unmarshal(decryptedPayload, &payload); unmarshalErr != nil { + return nil, fmt.Errorf("invalid decrypted payload: %w", unmarshalErr) + } + + // Acquire client lock to prevent race conditions + // between validation and database update + unlockClient := apimiddleware.LockClient(client.ClientID) + defer unlockClient() + + // Re-fetch client state under lock to get latest sequence/nonce state + freshClient, err := db.UserDB.GetClientByID(client.ClientID) + if err != nil { + return nil, fmt.Errorf("failed to re-fetch client under lock: %w", err) + } + + // Create replay protector from fresh client state + replayProtector := apimiddleware.NewReplayProtector(freshClient) + + // Validate sequence and nonce with fresh client state + if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) { + return nil, errors.New("invalid sequence or replay detected") + } + + // Update replay protector state + replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce) + + // Get updated state for database storage + currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase() + freshClient.CurrentSeq = currentSeq + freshClient.SeqWindow = seqWindow + freshClient.NonceCache = nonceCache + + if updateErr := db.UserDB.UpdateClientSequence( + freshClient.ClientID, freshClient.CurrentSeq, freshClient.SeqWindow, freshClient.NonceCache, + ); updateErr != nil { + return nil, fmt.Errorf("failed to update client state: %w", updateErr) + } + + // Return JSON-RPC payload without sequence/nonce + originalPayload := map[string]any{ + "jsonrpc": payload.JSONRPC, + "method": payload.Method, + "id": payload.ID, + } + if payload.Params != nil { + originalPayload["params"] = payload.Params + } + + result, err := json.Marshal(originalPayload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + return result, nil +} + +func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, client *database.Client) error { + // Create response object + resp := models.ResponseObject{ + JSONRPC: "2.0", + ID: id, + Result: result, + } + + // Marshal to JSON + data, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("error marshalling response: %w", err) + } + + // Encrypt the response + encrypted, iv, err := apimiddleware.EncryptPayload(data, client.SharedSecret) + if err != nil { + return fmt.Errorf("failed to encrypt response: %w", err) + } + + // Send encrypted response + encResponse := apimiddleware.EncryptedRequest{ + Encrypted: encrypted, + IV: iv, + AuthToken: "", // Not needed for responses + } + + encData, err := json.Marshal(encResponse) + if err != nil { + return fmt.Errorf("failed to marshal encrypted response: %w", err) + } + + if err := session.Write(encData); err != nil { + return fmt.Errorf("failed to send encrypted response: %w", err) + } + return nil +} + func handlePostRequest( methodMap *MethodMap, platform platforms.Platform, @@ -632,7 +943,7 @@ func handlePostRequest( } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusOK) _, err = w.Write(respBody) if err != nil { log.Error().Err(err).Msg("failed to write error response") @@ -675,14 +986,39 @@ func Start( rateLimiter.StartCleanup(st.GetContext()) r.Use(apimiddleware.HTTPRateLimitMiddleware(rateLimiter)) + + // Start global mutex cleanup routine with proper context + apimiddleware.StartGlobalMutexCleanup(st.GetContext()) + + // Start pairing session cleanup routine with proper context + StartPairingCleanup(st.GetContext()) + + // Generate random CSRF key once at startup for security + csrfKey := make([]byte, 32) + if _, err := rand.Read(csrfKey); err != nil { + log.Fatal().Err(err).Msg("failed to generate CSRF key") + } + + // Security headers middleware - protect against common attacks + secureMiddleware := secure.New(secure.Options{ + FrameDeny: true, // X-Frame-Options: DENY + ContentTypeNosniff: true, // X-Content-Type-Options: nosniff + BrowserXssFilter: true, // X-XSS-Protection: 1; mode=block + ContentSecurityPolicy: "default-src 'self'; connect-src 'self' ws: wss:; " + + "img-src 'self' data:; style-src 'self' 'unsafe-inline'", + ReferrerPolicy: "strict-origin-when-cross-origin", + }) + r.Use(secureMiddleware.Handler) + r.Use(middleware.Recoverer) r.Use(middleware.NoCache) r.Use(middleware.Timeout(config.APIRequestTimeout)) r.Use(cors.Handler(cors.Options{ - AllowedOrigins: dynamicAllowedOrigins, - AllowedMethods: []string{"GET", "POST", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Content-Type"}, - ExposedHeaders: []string{}, + AllowedOrigins: dynamicAllowedOrigins, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Content-Type", "X-CSRF-Token"}, + ExposedHeaders: []string{"X-CSRF-Token"}, + AllowCredentials: true, // Required for CSRF protection })) if strings.HasSuffix(config.AppVersion, "-dev") { @@ -700,29 +1036,72 @@ func Start( } go broadcastNotifications(st, session, notifications) - r.Get("/api", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: latest") - } + // Pairing endpoints with CSRF protection and enhanced rate limiting + r.Route("/api/pair", func(r chi.Router) { + // CSRF protection for pairing endpoints + r.Use(csrf.Protect( + csrfKey, + csrf.Secure(false), // Set to true when using HTTPS + csrf.HttpOnly(true), + csrf.SameSite(csrf.SameSiteStrictMode), + )) + + // Enhanced rate limiting for pairing (more restrictive) + pairingLimiter := apimiddleware.NewIPRateLimiter() + pairingLimiter.StartCleanup(st.GetContext()) + r.Use(apimiddleware.HTTPRateLimitMiddleware(pairingLimiter)) + + r.Post("/initiate", handlePairingInitiate(db)) + r.Post("/complete", handlePairingComplete(db)) + + // CSRF token endpoint for clients + r.Get("/csrf-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + token := csrf.Token(r) + + response := CSRFTokenResponse{ + CSRFToken: token, + } + + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("Failed to write CSRF token response") + } + }) }) - r.Post("/api", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) - r.Get("/api/v0", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: v0") - } + // Protected API routes with authentication middleware + r.Route("/api", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: latest") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) }) - r.Post("/api/v0", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) - r.Get("/api/v0.1", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: v0.1") - } + r.Route("/api/v0", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: v0") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) + }) + + r.Route("/api/v0.1", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: v0.1") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) }) - r.Post("/api/v0.1", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) session.HandleMessage(apimiddleware.WebSocketRateLimitHandler( rateLimiter, diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 7fcdc2b13..0ee72bb80 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -40,16 +40,19 @@ import ( ) type Flags struct { - Write *string - Read *bool - Run *string - Launch *string - API *string - Version *bool - Config *bool - ShowLoader *string - ShowPicker *string - Reload *bool + Write *string + Read *bool + Run *string + Launch *string + API *string + Version *bool + Config *bool + ShowLoader *string + ShowPairingCode *bool + ListClients *bool + RevokeClient *string + ShowPicker *string + Reload *bool } // SetupFlags defines all common CLI flags between platforms. @@ -95,6 +98,21 @@ func SetupFlags() *Flags { false, "reload config and mappings from disk", ), + ShowPairingCode: flag.Bool( + "show-pairing-code", + false, + "display QR code for client pairing", + ), + ListClients: flag.Bool( + "list-clients", + false, + "list all paired clients", + ), + RevokeClient: flag.String( + "revoke-client", + "", + "revoke access for client by ID", + ), } } @@ -139,8 +157,17 @@ func runFlag(cfg *config.Instance, value string) { // Post actions all remaining common flags that require the environment to be // set up. Logging is allowed. -func (f *Flags) Post(cfg *config.Instance, _ platforms.Platform) { +func (f *Flags) Post(cfg *config.Instance, pl platforms.Platform) { switch { + case *f.ShowPairingCode: + handleShowPairingCode(cfg, pl) + os.Exit(0) + case *f.ListClients: + handleListClients(cfg, pl) + os.Exit(0) + case isFlagPassed("revoke-client"): + handleRevokeClient(cfg, pl, *f.RevokeClient) + os.Exit(0) case isFlagPassed("write"): if *f.Write == "" { _, _ = fmt.Fprint(os.Stderr, "Error: write flag requires a value\n") diff --git a/pkg/cli/clients.go b/pkg/cli/clients.go new file mode 100644 index 000000000..09251e97b --- /dev/null +++ b/pkg/cli/clients.go @@ -0,0 +1,213 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package cli + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" + "github.com/rs/zerolog/log" +) + +func generateQRCode(data string) { + // Simple ASCII QR code placeholder - in a real implementation, you'd use a QR code library + if _, err := fmt.Print("\n=== PAIRING QR CODE ===\n"); err != nil { + log.Error().Err(err).Msg("failed to print header") + } + if _, err := fmt.Print("Scan this with your mobile app:\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print instruction") + } + + // For now, just display the data - a real implementation would generate ASCII QR code + if _, err := fmt.Printf("Data: %s\n\n", data); err != nil { + log.Error().Err(err).Msg("failed to print data") + } + + // You could use a library like github.com/skip2/go-qrcode for actual QR generation + note := "Note: QR code display not yet implemented - use manual pairing with the data above\n" + if _, err := fmt.Print(note); err != nil { + log.Error().Err(err).Msg("failed to print note") + } + if _, err := fmt.Print("======================\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print footer") + } +} + +func handleShowPairingCode(cfg *config.Instance, pl platforms.Platform) { + // Open user database to check pairing sessions + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + // Create HTTP client to call our own pairing API + client := &http.Client{Timeout: 10 * time.Second} + + // Call pairing initiate endpoint + apiURL := fmt.Sprintf("http://localhost:%d/api/pair/initiate", cfg.APIPort()) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, apiURL, strings.NewReader("{}")) + if err != nil { + log.Error().Err(err).Msg("failed to create pairing request") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("failed to initiate pairing") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error initiating pairing: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + defer func() { _ = resp.Body.Close() }() + + var pairingResp api.PairingInitiateResponse + if err := json.NewDecoder(resp.Body).Decode(&pairingResp); err != nil { + log.Error().Err(err).Msg("failed to decode pairing response") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error decoding response: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + // Generate QR code data + serverAddr := fmt.Sprintf("localhost:%d", cfg.APIPort()) + qrData := api.QRCodeData{ + Address: serverAddr, + Token: pairingResp.PairingToken, + } + + jsonData, _ := json.Marshal(qrData) + generateQRCode(string(jsonData)) + + if _, err := fmt.Printf("Pairing token: %s\n", pairingResp.PairingToken); err != nil { + log.Error().Err(err).Msg("failed to print pairing token") + } + if _, err := fmt.Printf("Expires in: %d seconds\n", pairingResp.ExpiresIn); err != nil { + log.Error().Err(err).Msg("failed to print expiration time") + } + if _, err := fmt.Print("\nWaiting for client to pair... (Ctrl+C to cancel)\n"); err != nil { + log.Error().Err(err).Msg("failed to print waiting message") + } + + // Wait for user to cancel + select {} +} + +func handleListClients(_ *config.Instance, pl platforms.Platform) { + // Open user database to list clients + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + clients, err := userDB.GetAllClients() + if err != nil { + log.Error().Err(err).Msg("failed to get clients") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error getting clients: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + if len(clients) == 0 { + if _, err := fmt.Println("No paired clients found."); err != nil { + log.Error().Err(err).Msg("failed to print message") + } + return + } + + if _, err := fmt.Print("Paired clients:\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print header") + } + if _, err := fmt.Printf("%-36s %-20s %-10s %s\n", "Client ID", "Name", "Sequence", "Last Seen"); err != nil { + log.Error().Err(err).Msg("failed to print column headers") + } + if _, err := fmt.Printf("%s\n", strings.Repeat("-", 80)); err != nil { + log.Error().Err(err).Msg("failed to print separator") + } + + for i := range clients { + client := &clients[i] + if _, err := fmt.Printf("%-36s %-20s %-10d %s\n", + client.ClientID, + client.ClientName, + client.CurrentSeq, + client.LastSeen.Format("2006-01-02 15:04:05"), + ); err != nil { + log.Error().Err(err).Msg("failed to print client info") + } + } +} + +func handleRevokeClient(_ *config.Instance, pl platforms.Platform, clientID string) { + if clientID == "" { + if _, err := fmt.Fprint(os.Stderr, "Error: client ID is required\n"); err != nil { + log.Error().Err(err).Msg("failed to write error message") + } + os.Exit(1) + } + + // Open user database to revoke client + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + err = userDB.DeleteClient(clientID) + if err != nil { + log.Error().Err(err).Msg("failed to delete client") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error deleting client: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + if _, err := fmt.Printf("Client %s has been revoked successfully.\n", clientID); err != nil { + log.Error().Err(err).Msg("failed to print success message") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 8bd6ab319..b1baca1f2 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -359,6 +359,12 @@ func (c *Instance) AllowedOrigins() []string { return c.vals.Service.AllowedOrigins } +func (c *Instance) DeviceID() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.vals.Service.DeviceID +} + func (c *Instance) IsExecuteAllowed(s string) bool { c.mu.RLock() defer c.mu.RUnlock() diff --git a/pkg/database/database.go b/pkg/database/database.go index ec7961a9c..73f9a4a83 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -69,6 +69,18 @@ type System struct { DBID int64 } +type Client struct { + CreatedAt time.Time `json:"createdAt"` + LastSeen time.Time `json:"lastSeen"` + ClientID string `json:"clientId"` + ClientName string `json:"clientName"` + AuthTokenHash string `json:"-"` + SharedSecret []byte `json:"-"` + SeqWindow []byte `json:"-"` + NonceCache []string `json:"-"` + CurrentSeq uint64 `json:"currentSeq"` +} + type MediaTitle struct { Slug string Name string @@ -154,6 +166,12 @@ type UserDBI interface { GetZapLinkHost(host string) (bool, bool, error) UpdateZapLinkCache(url string, zapscript string) error GetZapLinkCache(url string) (string, error) + CreateClient(clientName, authToken string, sharedSecret []byte) (*Client, error) + GetClientByAuthToken(authToken string) (*Client, error) + GetClientByID(clientID string) (*Client, error) + UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error + GetAllClients() ([]Client, error) + DeleteClient(clientID string) error } type MediaDBI interface { diff --git a/pkg/database/userdb/clients.go b/pkg/database/userdb/clients.go new file mode 100644 index 000000000..a8c7ad50d --- /dev/null +++ b/pkg/database/userdb/clients.go @@ -0,0 +1,280 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package userdb + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/google/uuid" +) + +func (db *UserDB) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + clientID := uuid.New().String() + authTokenHash := hashAuthToken(authToken) + now := time.Now().Unix() + + client := &database.Client{ + ClientID: clientID, + ClientName: clientName, + AuthTokenHash: authTokenHash, + SharedSecret: sharedSecret, + CurrentSeq: 0, + SeqWindow: make([]byte, 8), // 64-bit window + NonceCache: make([]string, 0), + CreatedAt: time.Unix(now, 0), + LastSeen: time.Unix(now, 0), + } + + nonceCacheJSON, err := json.Marshal(client.NonceCache) + if err != nil { + return nil, fmt.Errorf("failed to marshal nonce cache: %w", err) + } + + query := ` + INSERT INTO clients (client_id, client_name, auth_token_hash, shared_secret, + current_seq, seq_window, nonce_cache, created_at, last_seen) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + + _, err = db.sql.ExecContext(context.Background(), query, + client.ClientID, + client.ClientName, + client.AuthTokenHash, + client.SharedSecret, + client.CurrentSeq, + client.SeqWindow, + string(nonceCacheJSON), + now, + now, + ) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + return client, nil +} + +func (db *UserDB) GetClientByAuthToken(authToken string) (*database.Client, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + authTokenHash := hashAuthToken(authToken) + + query := ` + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM clients + WHERE auth_token_hash = ? + ` + + var client database.Client + var nonceCacheJSON string + var createdAt, lastSeen int64 + + err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan( + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if err != nil { + // Use a generic error to avoid leaking information about whether the token exists + return nil, errors.New("invalid credentials") + } + + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) + + if unmarshalErr := json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache); unmarshalErr != nil { + client.NonceCache = make([]string, 0) // Fallback to empty cache + } + + return &client, nil +} + +func (db *UserDB) GetClientByID(clientID string) (*database.Client, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + query := ` + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM clients + WHERE client_id = ? + ` + + var client database.Client + var nonceCacheJSON string + var createdAt, lastSeen int64 + + err := db.sql.QueryRowContext(context.Background(), query, clientID).Scan( + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if err != nil { + return nil, fmt.Errorf("client not found: %w", err) + } + + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) + + err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) + if err != nil { + client.NonceCache = make([]string, 0) // Fallback to empty cache + } + + return &client, nil +} + +func (db *UserDB) UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error { + if db.sql == nil { + return ErrNullSQL + } + + nonceCacheJSON, err := json.Marshal(nonceCache) + if err != nil { + return fmt.Errorf("failed to marshal nonce cache: %w", err) + } + + query := ` + UPDATE clients + SET current_seq = ?, seq_window = ?, nonce_cache = ?, last_seen = ? + WHERE client_id = ? + ` + + _, err = db.sql.ExecContext( + context.Background(), query, newSeq, seqWindow, string(nonceCacheJSON), time.Now().Unix(), clientID, + ) + if err != nil { + return fmt.Errorf("failed to update client sequence: %w", err) + } + + return nil +} + +func (db *UserDB) GetAllClients() ([]database.Client, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + query := ` + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM clients + ORDER BY last_seen DESC + ` + + rows, err := db.sql.QueryContext(context.Background(), query) + if err != nil { + return nil, fmt.Errorf("failed to query clients: %w", err) + } + defer func() { _ = rows.Close() }() + + clients := make([]database.Client, 0) + for rows.Next() { + var client database.Client + var nonceCacheJSON string + var createdAt, lastSeen int64 + + scanErr := rows.Scan( + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if scanErr != nil { + return nil, fmt.Errorf("failed to scan client row: %w", scanErr) + } + + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) + + err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) + if err != nil { + client.NonceCache = make([]string, 0) // Fallback to empty cache + } + + clients = append(clients, client) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error reading client rows: %w", err) + } + + return clients, nil +} + +func (db *UserDB) DeleteClient(clientID string) error { + if db.sql == nil { + return ErrNullSQL + } + + query := `DELETE FROM clients WHERE client_id = ?` + result, err := db.sql.ExecContext(context.Background(), query, clientID) + if err != nil { + return fmt.Errorf("failed to delete client: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return errors.New("client not found") + } + + return nil +} + +func hashAuthToken(authToken string) string { + hash := sha256.Sum256([]byte(authToken)) + return hex.EncodeToString(hash[:]) +} diff --git a/pkg/database/userdb/migrations/20250912231204_clients_auth.sql b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql new file mode 100644 index 000000000..b3380c6d7 --- /dev/null +++ b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE clients ( + client_id TEXT PRIMARY KEY, + client_name TEXT NOT NULL, + auth_token_hash TEXT NOT NULL UNIQUE, + shared_secret BLOB NOT NULL, + current_seq INTEGER DEFAULT 0, + seq_window BLOB, + nonce_cache TEXT DEFAULT '[]', + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + last_seen INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) +); + +CREATE INDEX idx_clients_auth_token ON clients(auth_token_hash); +CREATE INDEX idx_clients_last_seen ON clients(last_seen); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE clients; +-- +goose StatementEnd \ No newline at end of file diff --git a/pkg/testing/helpers/db_mocks.go b/pkg/testing/helpers/db_mocks.go index b95b4e178..2607e8d64 100644 --- a/pkg/testing/helpers/db_mocks.go +++ b/pkg/testing/helpers/db_mocks.go @@ -57,6 +57,11 @@ import ( "github.com/stretchr/testify/mock" ) +// Sentinel errors for mock functions +var ( + ErrMockNotConfigured = errors.New("mock not configured for this call") +) + // MockUserDBI is a mock implementation of the UserDBI interface using testify/mock type MockUserDBI struct { mock.Mock @@ -239,6 +244,81 @@ func (m *MockUserDBI) GetZapLinkCache(url string) (string, error) { return args.String(0), args.Error(1) } +// Client authentication methods +func (m *MockUserDBI) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) { + args := m.Called(clientName, authToken, sharedSecret) + if client, ok := args.Get(0).(*database.Client); ok { + if err := args.Error(1); err != nil { + return client, fmt.Errorf("mock UserDBI create client failed: %w", err) + } + return client, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI create client failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) GetClientByAuthToken(authToken string) (*database.Client, error) { + args := m.Called(authToken) + if client, ok := args.Get(0).(*database.Client); ok { + if err := args.Error(1); err != nil { + return client, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err) + } + return client, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) GetClientByID(clientID string) (*database.Client, error) { + args := m.Called(clientID) + if client, ok := args.Get(0).(*database.Client); ok { + if err := args.Error(1); err != nil { + return client, fmt.Errorf("mock UserDBI get client by ID failed: %w", err) + } + return client, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get client by ID failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) UpdateClientSequence( + clientID string, newSeq uint64, seqWindow []byte, nonceCache []string, +) error { + args := m.Called(clientID, newSeq, seqWindow, nonceCache) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI update client sequence failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) GetAllClients() ([]database.Client, error) { + args := m.Called() + if clients, ok := args.Get(0).([]database.Client); ok { + if err := args.Error(1); err != nil { + return clients, fmt.Errorf("mock UserDBI get all clients failed: %w", err) + } + return clients, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get all clients failed: %w", err) + } + return nil, nil +} + +func (m *MockUserDBI) DeleteClient(clientID string) error { + args := m.Called(clientID) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI delete client failed: %w", err) + } + return nil +} + // MockMediaDBI is a mock implementation of the MediaDBI interface using testify/mock type MockMediaDBI struct { mock.Mock