diff --git a/docs/api/authentication.md b/docs/api/authentication.md
new file mode 100644
index 000000000..35b580536
--- /dev/null
+++ b/docs/api/authentication.md
@@ -0,0 +1,618 @@
+# Zaparoo Core Secure API Authentication
+
+This document provides a complete guide for developers implementing clients that connect to Zaparoo Core using the secure authentication layer.
+
+## Table of Contents
+
+- [Overview](#overview)
+- [Security Model](#security-model)
+- [Client Pairing](#client-pairing)
+- [HTTP API Usage](#http-api-usage)
+- [WebSocket API Usage](#websocket-api-usage)
+- [Example Implementations](#example-implementations)
+- [Error Handling](#error-handling)
+- [Security Best Practices](#security-best-practices)
+
+## Overview
+
+Zaparoo Core implements a secure authentication system for remote client connections while maintaining backward compatibility for localhost connections. The security model uses:
+
+- **AES-256-GCM** encryption for all remote API communications
+- **Argon2id + HKDF** key derivation for client pairing
+- **Sequence numbers + nonces** for replay attack protection
+- **Per-client state management** for concurrent access safety
+
+### Connection Types
+
+| Connection Type | Authentication Required | Encryption Required |
+| -------------------------- | ----------------------- | ------------------- |
+| Localhost (127.0.0.1, ::1) | ❌ No | ❌ No |
+| Remote (all other IPs) | ✅ Yes | ✅ Yes |
+
+## Security Model
+
+### Encryption Format
+
+All remote API requests must use the following encrypted format:
+
+```json
+{
+ "encrypted": "base64-encoded-ciphertext",
+ "iv": "base64-encoded-initialization-vector",
+ "authToken": "client-auth-token"
+}
+```
+
+### Decrypted Payload Format
+
+The decrypted payload contains the standard JSON-RPC request plus security fields:
+
+```json
+{
+ "jsonrpc": "2.0",
+ "method": "system.version",
+ "id": 1,
+ "params": {},
+ "seq": 123,
+ "nonce": "unique-request-nonce"
+}
+```
+
+## Client Pairing
+
+Before making authenticated requests, clients must complete a pairing process to obtain shared encryption keys.
+
+### Step 1: Initiate Pairing
+
+**Request:**
+
+```http
+POST /api/pair/initiate
+Content-Type: application/json
+
+{
+ "clientName": "MyApp_v1.0"
+}
+```
+
+**Response:**
+
+```json
+{
+ "pairingToken": "550e8400-e29b-41d4-a716-446655440000",
+ "expiresIn": 300
+}
+```
+
+### Step 2: Complete Pairing
+
+The client must obtain a verification code through an out-of-band method (QR code, manual entry, etc.) and complete pairing.
+
+**Request:**
+
+```http
+POST /api/pair/complete
+Content-Type: application/json
+
+{
+ "pairingToken": "550e8400-e29b-41d4-a716-446655440000",
+ "verifier": "user-provided-verification-code",
+ "clientName": "MyApp_v1.0"
+}
+```
+
+**Response:**
+
+```json
+{
+ "clientId": "client-uuid-here",
+ "authToken": "auth-token-uuid-here",
+ "sharedSecret": "64-character-hex-encoded-key"
+}
+```
+
+### Client Name Requirements
+
+- Length: 1-100 characters
+- Characters: letters, numbers, underscore, dash only (`^[a-zA-Z0-9_-]+$`)
+- Must be unique per client instance
+
+## HTTP API Usage
+
+### Making Authenticated Requests
+
+1. **Prepare the JSON-RPC payload** with sequence number and nonce
+2. **Encrypt the payload** using AES-256-GCM
+3. **Send the encrypted request** to the API endpoint
+
+### Example: Get System Version
+
+```javascript
+// 1. Prepare JSON-RPC payload
+const payload = {
+ jsonrpc: "2.0",
+ method: "system.version",
+ id: 1,
+ seq: ++sequenceNumber,
+ nonce: generateNonce(),
+};
+
+// 2. Encrypt payload
+const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret);
+
+// 3. Send request
+const response = await fetch("http://zaparoo-host:7497/api/v0.1", {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({
+ encrypted: encrypted.ciphertext,
+ iv: encrypted.iv,
+ authToken: authToken,
+ }),
+});
+
+// 4. Decrypt response
+const encryptedResponse = await response.json();
+const decryptedResponse = await decryptPayload(
+ encryptedResponse.encrypted,
+ encryptedResponse.iv,
+ sharedSecret
+);
+```
+
+## WebSocket API Usage
+
+WebSocket connections require a two-step authentication process:
+
+### Step 1: Connect and Authenticate
+
+```javascript
+const ws = new WebSocket("ws://zaparoo-host:7497/api/v0.1");
+
+ws.onopen = () => {
+ // Send authentication message
+ ws.send(
+ JSON.stringify({
+ authToken: authToken,
+ })
+ );
+};
+
+ws.onmessage = (event) => {
+ const message = JSON.parse(event.data);
+
+ if (message.authenticated) {
+ console.log("WebSocket authenticated successfully");
+ // Now ready to send encrypted requests
+ } else {
+ // Handle encrypted response
+ handleEncryptedMessage(message);
+ }
+};
+```
+
+### Step 2: Send Encrypted Requests
+
+```javascript
+async function sendRequest(method, params = {}) {
+ const payload = {
+ jsonrpc: "2.0",
+ method: method,
+ id: generateId(),
+ params: params,
+ seq: ++sequenceNumber,
+ nonce: generateNonce(),
+ };
+
+ const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret);
+
+ ws.send(
+ JSON.stringify({
+ encrypted: encrypted.ciphertext,
+ iv: encrypted.iv,
+ authToken: authToken,
+ })
+ );
+}
+
+// Example usage
+await sendRequest("system.version");
+await sendRequest("search.games", { query: "mario" });
+```
+
+### Handling Encrypted Responses
+
+```javascript
+async function handleEncryptedMessage(encryptedMessage) {
+ if (encryptedMessage.encrypted && encryptedMessage.iv) {
+ const decrypted = await decryptPayload(
+ encryptedMessage.encrypted,
+ encryptedMessage.iv,
+ sharedSecret
+ );
+
+ const response = JSON.parse(decrypted);
+ console.log("Received response:", response);
+ }
+}
+```
+
+## Example Implementations
+
+### JavaScript/Node.js Client
+
+```javascript
+import crypto from "crypto";
+
+class ZaparooSecureClient {
+ constructor(host, port = 7497) {
+ this.baseUrl = `http://${host}:${port}`;
+ this.wsUrl = `ws://${host}:${port}`;
+ this.sequenceNumber = 0;
+ this.authToken = null;
+ this.sharedSecret = null;
+ }
+
+ // Pairing process
+ async pair(clientName, verifier) {
+ // Step 1: Initiate pairing
+ const initResponse = await fetch(`${this.baseUrl}/api/pair/initiate`, {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({ clientName }),
+ });
+
+ const { pairingToken } = await initResponse.json();
+
+ // Step 2: Complete pairing (verifier obtained out-of-band)
+ const completeResponse = await fetch(`${this.baseUrl}/api/pair/complete`, {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({
+ pairingToken,
+ verifier,
+ clientName,
+ }),
+ });
+
+ const result = await completeResponse.json();
+ this.authToken = result.authToken;
+ this.sharedSecret = Buffer.from(result.sharedSecret, "hex");
+
+ return result;
+ }
+
+ // Encryption utilities
+ async encryptPayload(data) {
+ const iv = crypto.randomBytes(12);
+ const cipher = crypto.createCipherGCM("aes-256-gcm", this.sharedSecret);
+ cipher.setAAD(Buffer.alloc(0));
+
+ const encrypted = Buffer.concat([
+ cipher.update(data, "utf8"),
+ cipher.final(),
+ ]);
+
+ const authTag = cipher.getAuthTag();
+ const ciphertext = Buffer.concat([encrypted, authTag]);
+
+ return {
+ ciphertext: ciphertext.toString("base64"),
+ iv: iv.toString("base64"),
+ };
+ }
+
+ async decryptPayload(encryptedB64, ivB64) {
+ const encrypted = Buffer.from(encryptedB64, "base64");
+ const iv = Buffer.from(ivB64, "base64");
+
+ const authTag = encrypted.slice(-16);
+ const ciphertext = encrypted.slice(0, -16);
+
+ const decipher = crypto.createDecipherGCM("aes-256-gcm", this.sharedSecret);
+ decipher.setAuthTag(authTag);
+ decipher.setAAD(Buffer.alloc(0));
+
+ const decrypted = Buffer.concat([
+ decipher.update(ciphertext),
+ decipher.final(),
+ ]);
+
+ return decrypted.toString("utf8");
+ }
+
+ generateNonce() {
+ return crypto.randomBytes(16).toString("hex");
+ }
+
+ // HTTP API request
+ async request(method, params = {}) {
+ const payload = {
+ jsonrpc: "2.0",
+ method,
+ id: Date.now(),
+ params,
+ seq: ++this.sequenceNumber,
+ nonce: this.generateNonce(),
+ };
+
+ const encrypted = await this.encryptPayload(JSON.stringify(payload));
+
+ const response = await fetch(`${this.baseUrl}/api/v0.1`, {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({
+ encrypted: encrypted.ciphertext,
+ iv: encrypted.iv,
+ authToken: this.authToken,
+ }),
+ });
+
+ const encryptedResponse = await response.json();
+ const decryptedData = await this.decryptPayload(
+ encryptedResponse.encrypted,
+ encryptedResponse.iv
+ );
+
+ return JSON.parse(decryptedData);
+ }
+}
+
+// Usage example
+const client = new ZaparooSecureClient("192.168.1.100");
+
+// Pair with server (verifier obtained from QR code/user input)
+await client.pair("MyApp_v1.0", "verification-code-from-qr");
+
+// Make API requests
+const version = await client.request("system.version");
+const games = await client.request("search.games", { query: "zelda" });
+```
+
+### Python Client
+
+```python
+import asyncio
+import json
+import secrets
+from cryptography.hazmat.primitives.ciphers.aead import AESGCM
+import aiohttp
+import websockets
+
+class ZaparooSecureClient:
+ def __init__(self, host, port=7497):
+ self.base_url = f"http://{host}:{port}"
+ self.ws_url = f"ws://{host}:{port}"
+ self.sequence_number = 0
+ self.auth_token = None
+ self.shared_secret = None
+
+ async def pair(self, client_name, verifier):
+ async with aiohttp.ClientSession() as session:
+ # Initiate pairing
+ async with session.post(
+ f"{self.base_url}/api/pair/initiate",
+ json={"clientName": client_name}
+ ) as response:
+ result = await response.json()
+ pairing_token = result["pairingToken"]
+
+ # Complete pairing
+ async with session.post(
+ f"{self.base_url}/api/pair/complete",
+ json={
+ "pairingToken": pairing_token,
+ "verifier": verifier,
+ "clientName": client_name
+ }
+ ) as response:
+ result = await response.json()
+ self.auth_token = result["authToken"]
+ self.shared_secret = bytes.fromhex(result["sharedSecret"])
+ return result
+
+ def encrypt_payload(self, data):
+ aesgcm = AESGCM(self.shared_secret)
+ nonce = secrets.token_bytes(12)
+ ciphertext = aesgcm.encrypt(nonce, data.encode(), None)
+
+ return {
+ "ciphertext": ciphertext.hex(),
+ "iv": nonce.hex()
+ }
+
+ def decrypt_payload(self, encrypted_hex, iv_hex):
+ aesgcm = AESGCM(self.shared_secret)
+ ciphertext = bytes.fromhex(encrypted_hex)
+ nonce = bytes.fromhex(iv_hex)
+
+ plaintext = aesgcm.decrypt(nonce, ciphertext, None)
+ return plaintext.decode()
+
+ def generate_nonce(self):
+ return secrets.token_hex(16)
+
+ async def request(self, method, params=None):
+ if params is None:
+ params = {}
+
+ payload = {
+ "jsonrpc": "2.0",
+ "method": method,
+ "id": 1,
+ "params": params,
+ "seq": self.sequence_number + 1,
+ "nonce": self.generate_nonce()
+ }
+ self.sequence_number += 1
+
+ encrypted = self.encrypt_payload(json.dumps(payload))
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ f"{self.base_url}/api/v0.1",
+ json={
+ "encrypted": encrypted["ciphertext"],
+ "iv": encrypted["iv"],
+ "authToken": self.auth_token
+ }
+ ) as response:
+ encrypted_response = await response.json()
+ decrypted = self.decrypt_payload(
+ encrypted_response["encrypted"],
+ encrypted_response["iv"]
+ )
+ return json.loads(decrypted)
+
+# Usage
+client = ZaparooSecureClient("192.168.1.100")
+await client.pair("MyPythonApp", "verification-code")
+result = await client.request("system.version")
+```
+
+## Error Handling
+
+### Common Error Responses
+
+| Error | HTTP Status | Description |
+| ---------------------- | ----------- | -------------------------------- |
+| Invalid auth token | 401 | Token not found or expired |
+| Invalid request format | 400 | Malformed JSON or missing fields |
+| Decryption failed | 400 | Invalid encryption or wrong key |
+| Invalid sequence | 400 | Replay attack detected |
+| Method not allowed | 405 | Wrong HTTP method |
+
+### Example Error Response
+
+```json
+{
+ "jsonrpc": "2.0",
+ "id": 1,
+ "error": {
+ "code": -32600,
+ "message": "Invalid Request"
+ }
+}
+```
+
+### Client-Side Error Handling
+
+```javascript
+try {
+ const result = await client.request("system.version");
+ console.log(result);
+} catch (error) {
+ if (error.response?.status === 401) {
+ console.log("Authentication failed - re-pair required");
+ await client.pair(clientName, newVerifier);
+ } else if (error.response?.status === 400) {
+ console.log("Request error:", error.response.data);
+ } else {
+ console.log("Network error:", error.message);
+ }
+}
+```
+
+## Security Best Practices
+
+### For Client Developers
+
+1. **Secure Key Storage**: Store shared secrets securely (keychain, encrypted storage)
+2. **Sequence Management**: Persist sequence numbers across app restarts
+3. **Nonce Uniqueness**: Generate cryptographically random nonces
+4. **Error Handling**: Don't expose sensitive information in logs
+5. **Network Security**: Use TLS for additional transport security
+6. **Token Rotation**: Implement re-pairing for long-running applications
+
+### Example Secure Storage (Node.js)
+
+```javascript
+import keytar from "keytar";
+
+class SecureStorage {
+ static async storeCredentials(
+ clientId,
+ authToken,
+ sharedSecret,
+ sequenceNumber
+ ) {
+ await keytar.setPassword("zaparoo", `${clientId}-auth`, authToken);
+ await keytar.setPassword("zaparoo", `${clientId}-secret`, sharedSecret);
+ await keytar.setPassword(
+ "zaparoo",
+ `${clientId}-seq`,
+ sequenceNumber.toString()
+ );
+ }
+
+ static async loadCredentials(clientId) {
+ const authToken = await keytar.getPassword("zaparoo", `${clientId}-auth`);
+ const sharedSecret = await keytar.getPassword(
+ "zaparoo",
+ `${clientId}-secret`
+ );
+ const sequenceNumber = parseInt(
+ (await keytar.getPassword("zaparoo", `${clientId}-seq`)) || "0"
+ );
+
+ return { authToken, sharedSecret, sequenceNumber };
+ }
+}
+```
+
+### Sequence Number Management
+
+```javascript
+class SequenceManager {
+ constructor(clientId) {
+ this.clientId = clientId;
+ this.sequenceNumber = this.loadSequenceNumber();
+ }
+
+ getNext() {
+ this.sequenceNumber++;
+ this.saveSequenceNumber();
+ return this.sequenceNumber;
+ }
+
+ loadSequenceNumber() {
+ // Load from persistent storage
+ return parseInt(
+ localStorage.getItem(`zaparoo-seq-${this.clientId}`) || "0"
+ );
+ }
+
+ saveSequenceNumber() {
+ localStorage.setItem(
+ `zaparoo-seq-${this.clientId}`,
+ this.sequenceNumber.toString()
+ );
+ }
+}
+```
+
+## API Endpoints Reference
+
+### Available Endpoints
+
+| Endpoint | Authentication | Description |
+| ------------------------- | ----------------- | ------------------------ |
+| `POST /api/pair/initiate` | None | Start pairing process |
+| `POST /api/pair/complete` | None | Complete pairing process |
+| `POST /api/v0.1` | Required (remote) | JSON-RPC API |
+| `WS /api/v0.1` | Required (remote) | WebSocket API |
+
+### Supported Methods
+
+Once authenticated, all standard Zaparoo Core JSON-RPC methods are available:
+
+- `system.version` - Get system version
+- `system.heartbeat` - Health check
+- `search.games` - Search game database
+- `launch.game` - Launch a game
+- `media.scan` - Trigger media scan
+- And more... (see main API documentation)
+
+---
+
+**Note**: This secure authentication layer is only required for remote connections. Localhost connections (127.0.0.1, ::1) continue to work without authentication for development and local tool access.
diff --git a/go.mod b/go.mod
index 9a19ffedb..4d1dedcb7 100755
--- a/go.mod
+++ b/go.mod
@@ -36,6 +36,7 @@ require (
go.bug.st/serial v1.6.4
go.etcd.io/bbolt v1.4.0
golang.design/x/clipboard v0.7.0
+ golang.org/x/crypto v0.38.0
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6
golang.org/x/sys v0.35.0
golang.org/x/text v0.26.0
@@ -54,6 +55,8 @@ require (
github.com/geoffgarside/ber v1.1.0 // indirect
github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
+ github.com/gorilla/csrf v1.7.3 // indirect
+ github.com/gorilla/securecookie v1.1.2 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
@@ -73,8 +76,8 @@ require (
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af // indirect
+ github.com/unrolled/secure v1.17.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
- golang.org/x/crypto v0.38.0 // indirect
golang.org/x/exp/shiny v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
golang.org/x/image v0.20.0 // indirect
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect
diff --git a/go.sum b/go.sum
index c2abad4e7..bbe7e1339 100644
--- a/go.sum
+++ b/go.sum
@@ -59,8 +59,12 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
+github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
+github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
+github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
@@ -143,6 +147,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af h1:6yITBqGTE2lEeTPG04SN9W+iWHCRyHqlVYILiSXziwk=
github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af/go.mod h1:4F09kP5F+am0jAwlQLddpoMDM+iewkxxt6nxUQ5nq5o=
+github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU=
+github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.bug.st/serial v1.6.4 h1:7FmqNPgVp3pu2Jz5PoPtbZ9jJO5gnEnZIvnI1lzve8A=
go.bug.st/serial v1.6.4/go.mod h1:nofMJxTeNVny/m6+KaafC6vJGj3miwQZ6vW4BZUGJPI=
diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go
new file mode 100644
index 000000000..bcafa3b39
--- /dev/null
+++ b/pkg/api/middleware/auth.go
@@ -0,0 +1,399 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package middleware
+
+import (
+ "bytes"
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database"
+ "github.com/rs/zerolog/log"
+)
+
+const (
+ MutexCleanupInterval = 10 * time.Minute // Cleanup unused mutexes every 10 minutes
+ MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes
+)
+
+type clientKey string
+
+// ClientMutexManager handles per-client locking to prevent race conditions
+// in authentication state updates
+type ClientMutexManager struct {
+ mutexes sync.Map // map[string]*clientMutex
+}
+
+// ClientMutex represents a per-client mutex for thread-safe operations
+type ClientMutex struct {
+ lastUsed time.Time
+ clientID string
+ mu sync.Mutex
+}
+
+// Legacy alias for backward compatibility
+type clientMutex = ClientMutex
+
+var globalClientMutexManager = &ClientMutexManager{}
+
+type EncryptedRequest struct {
+ Encrypted string `json:"encrypted"`
+ IV string `json:"iv"`
+ AuthToken string `json:"authToken"`
+}
+
+type DecryptedPayload struct {
+ ID any `json:"id,omitempty"`
+ JSONRPC string `json:"jsonrpc"`
+ Method string `json:"method"`
+ Nonce string `json:"nonce"`
+ Params json.RawMessage `json:"params,omitempty"`
+ Seq uint64 `json:"seq"`
+}
+
+func isLocalhost(remoteAddr string) bool {
+ host, _, err := net.SplitHostPort(remoteAddr)
+ if err != nil {
+ host = remoteAddr
+ }
+
+ ip := net.ParseIP(host)
+ if ip == nil {
+ return host == "localhost"
+ }
+
+ return ip.IsLoopback()
+}
+
+func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Skip authentication for localhost connections
+ if isLocalhost(r.RemoteAddr) {
+ log.Debug().Str("remote_addr", r.RemoteAddr).Msg("localhost connection - skipping auth")
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // Read request body
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to read request body")
+ http.Error(w, "Failed to read request body", http.StatusInternalServerError)
+ return
+ }
+
+ // Try to parse as encrypted request
+ var encReq EncryptedRequest
+ if parseErr := json.Unmarshal(body, &encReq); parseErr != nil {
+ log.Error().Err(parseErr).Msg("invalid encrypted request format")
+ http.Error(w, "Invalid request format", http.StatusBadRequest)
+ return
+ }
+
+ // First, validate auth token and get client for initial validation
+ initialClient, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken)
+ if err != nil {
+ tokenStr := "empty"
+ if len(encReq.AuthToken) >= 8 {
+ tokenStr = encReq.AuthToken[:8] + "..."
+ } else if encReq.AuthToken != "" {
+ tokenStr = encReq.AuthToken
+ }
+ log.Warn().Err(err).
+ Str("token", tokenStr).
+ Str("remote_addr", r.RemoteAddr).
+ Str("user_agent", r.Header.Get("User-Agent")).
+ Msg("SECURITY: invalid auth token - potential attack")
+ http.Error(w, "Invalid auth token", http.StatusUnauthorized)
+ return
+ }
+
+ // Acquire client lock BEFORE any decryption or validation
+ // to prevent race conditions between concurrent requests
+ unlockClient := LockClient(initialClient.ClientID)
+ defer unlockClient()
+
+ // Re-fetch client state under lock to get latest sequence/nonce state
+ client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to re-fetch client under lock")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ // Decrypt payload with locked client data
+ decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, client.SharedSecret)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to decrypt payload")
+ http.Error(w, "Decryption failed", http.StatusBadRequest)
+ return
+ }
+
+ // Parse decrypted payload
+ var payload DecryptedPayload
+ if parseErr := json.Unmarshal(decryptedPayload, &payload); parseErr != nil {
+ log.Error().Err(parseErr).Msg("invalid decrypted payload format")
+ http.Error(w, "Invalid payload format", http.StatusBadRequest)
+ return
+ }
+
+ // Create replay protector from client state
+ replayProtector := NewReplayProtector(client)
+
+ // Validate sequence number and nonce with locked client state
+ if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) {
+ log.Warn().
+ Str("client_id", client.ClientID).
+ Uint64("seq", payload.Seq).
+ Str("nonce", payload.Nonce).
+ Str("remote_addr", r.RemoteAddr).
+ Str("user_agent", r.Header.Get("User-Agent")).
+ Msg("SECURITY: replay attack detected")
+ http.Error(w, "Invalid sequence or replay detected", http.StatusBadRequest)
+ return
+ }
+
+ // Update replay protector state
+ replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce)
+
+ // Get updated state for database storage
+ currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase()
+ client.CurrentSeq = currentSeq
+ client.SeqWindow = seqWindow
+ client.NonceCache = nonceCache
+
+ // Save to database (still under lock)
+ if updateErr := db.UserDB.UpdateClientSequence(
+ client.ClientID, client.CurrentSeq, client.SeqWindow, client.NonceCache,
+ ); updateErr != nil {
+ log.Error().Err(updateErr).Msg("failed to update client state")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ // Replace request body with decrypted JSON-RPC payload
+ originalPayload := map[string]any{
+ "jsonrpc": payload.JSONRPC,
+ "method": payload.Method,
+ "id": payload.ID,
+ }
+ if payload.Params != nil {
+ originalPayload["params"] = payload.Params
+ }
+
+ newBody, err := json.Marshal(originalPayload)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to marshal decrypted payload")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ // Create new request with decrypted body
+ r.Body = io.NopCloser(bytes.NewReader(newBody))
+ r.ContentLength = int64(len(newBody))
+
+ // Store client in context for potential use by handlers
+ ctx := context.WithValue(r.Context(), clientKey("client"), client)
+ r = r.WithContext(ctx)
+
+ log.Info().
+ Str("client_id", client.ClientID).
+ Str("method", payload.Method).
+ Uint64("seq", payload.Seq).
+ Str("remote_addr", r.RemoteAddr).
+ Msg("SECURITY: authenticated request processed")
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
+
+func DecryptPayload(encryptedB64, ivB64 string, key []byte) ([]byte, error) {
+ // Decode base64
+ encrypted, err := base64.StdEncoding.DecodeString(encryptedB64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid encrypted data: %w", err)
+ }
+
+ iv, err := base64.StdEncoding.DecodeString(ivB64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid IV: %w", err)
+ }
+
+ // Create AES cipher
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cipher: %w", err)
+ }
+
+ // Create GCM mode
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create GCM: %w", err)
+ }
+
+ // Decrypt
+ plaintext, err := gcm.Open(nil, iv, encrypted, nil)
+ if err != nil {
+ return nil, fmt.Errorf("decryption failed: %w", err)
+ }
+
+ return plaintext, nil
+}
+
+func EncryptPayload(data, key []byte) (encrypted, iv string, err error) {
+ // Create AES cipher
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to create cipher: %w", err)
+ }
+
+ // Create GCM mode
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to create GCM: %w", err)
+ }
+
+ // Generate random IV
+ ivBytes := make([]byte, gcm.NonceSize())
+ if _, err := io.ReadFull(rand.Reader, ivBytes); err != nil {
+ return "", "", fmt.Errorf("failed to generate IV: %w", err)
+ }
+
+ // Encrypt
+ ciphertext := gcm.Seal(nil, ivBytes, data, nil)
+
+ // Return base64 encoded values
+ return base64.StdEncoding.EncodeToString(ciphertext),
+ base64.StdEncoding.EncodeToString(ivBytes), nil
+}
+
+// getClientMutex retrieves or creates a mutex for the given client ID
+func (cm *ClientMutexManager) getClientMutex(clientID string) *ClientMutex {
+ // Try to load existing mutex
+ if value, exists := cm.mutexes.Load(clientID); exists {
+ mutex, ok := value.(*clientMutex)
+ if !ok {
+ log.Error().Str("client_id", clientID).Msg("invalid mutex type in cache")
+ return nil
+ }
+ mutex.lastUsed = time.Now()
+ return mutex
+ }
+
+ // Create new mutex
+ newMutex := &ClientMutex{
+ lastUsed: time.Now(),
+ clientID: clientID,
+ }
+
+ // Store and return the mutex (LoadOrStore handles race conditions)
+ actual, _ := cm.mutexes.LoadOrStore(clientID, newMutex)
+ actualMutex, ok := actual.(*clientMutex)
+ if !ok {
+ log.Error().Str("client_id", clientID).Msg("invalid mutex type after LoadOrStore")
+ return nil
+ }
+ actualMutex.lastUsed = time.Now()
+ return actualMutex
+}
+
+// lockClient acquires a lock for the specified client, preventing race conditions
+// in authentication state updates. The returned function must be called to release the lock.
+func (cm *ClientMutexManager) lockClient(clientID string) func() {
+ mutex := cm.getClientMutex(clientID)
+ mutex.mu.Lock()
+
+ return func() {
+ mutex.mu.Unlock()
+ }
+}
+
+// cleanup removes unused mutexes to prevent memory leaks
+func (cm *ClientMutexManager) cleanup() {
+ now := time.Now()
+ cm.mutexes.Range(func(key, value any) bool {
+ mutex, ok := value.(*clientMutex)
+ if !ok {
+ log.Error().Interface("key", key).Msg("invalid mutex type in cleanup")
+ return true
+ }
+ if now.Sub(mutex.lastUsed) > MutexMaxIdle {
+ cm.mutexes.Delete(key)
+ log.Debug().Str("client_id", mutex.clientID).Msg("cleaned up unused client mutex")
+ }
+ return true
+ })
+}
+
+// StartCleanupRoutine starts a background goroutine to periodically clean up unused mutexes
+func (cm *ClientMutexManager) StartCleanupRoutine(ctx context.Context) {
+ go func() {
+ ticker := time.NewTicker(MutexCleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ log.Debug().Msg("client mutex cleanup routine stopped")
+ return
+ case <-ticker.C:
+ cm.cleanup()
+ }
+ }
+ }()
+}
+
+// GetClientMutex is a convenience function to get a mutex for a client
+func GetClientMutex(clientID string) *ClientMutex {
+ return globalClientMutexManager.getClientMutex(clientID)
+}
+
+// LockClient is a convenience function to lock a client
+func LockClient(clientID string) func() {
+ return globalClientMutexManager.lockClient(clientID)
+}
+
+// StartGlobalMutexCleanup starts the global mutex cleanup routine
+func StartGlobalMutexCleanup(ctx context.Context) {
+ globalClientMutexManager.StartCleanupRoutine(ctx)
+}
+
+func IsAuthenticatedConnection(r *http.Request) bool {
+ return !isLocalhost(r.RemoteAddr)
+}
+
+func GetClientFromContext(ctx context.Context) *database.Client {
+ if client, ok := ctx.Value(clientKey("client")).(*database.Client); ok {
+ return client
+ }
+ return nil
+}
diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go
new file mode 100644
index 000000000..96325d943
--- /dev/null
+++ b/pkg/api/middleware/auth_test.go
@@ -0,0 +1,638 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package middleware
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database"
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+//nolint:paralleltest,tparallel // Security tests require deterministic mock validation
+func TestAuthMiddleware_LocalhostBypass(t *testing.T) {
+ t.Parallel()
+ // Setup
+ userDB := helpers.NewMockUserDBI()
+ db := &database.Database{UserDB: userDB}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("success"))
+ })
+
+ middleware := AuthMiddleware(db)
+ wrappedHandler := middleware(handler)
+
+ tests := []struct {
+ name string
+ remoteAddr string
+ expectPass bool
+ }{
+ {"localhost with port", "127.0.0.1:12345", true},
+ {"localhost without port", "127.0.0.1", true},
+ {"localhost name with port", "localhost:8080", true},
+ {"localhost name without port", "localhost", true},
+ {"IPv6 loopback", "[::1]:8080", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`)))
+ req.RemoteAddr = tt.remoteAddr
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code, "localhost should bypass auth")
+ assert.Equal(t, "success", w.Body.String())
+ })
+ }
+}
+
+//nolint:paralleltest,tparallel // Security tests require deterministic mock validation
+func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) {
+ t.Parallel()
+ // Setup
+ userDB := helpers.NewMockUserDBI()
+ // Mock the GetClientByAuthToken call to return an error for empty token
+ // The middleware calls this once per test case (we have 2 test cases)
+ userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Times(2)
+
+ db := &database.Database{UserDB: userDB}
+
+ handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ t.Error("handler should not be called for unauthenticated remote requests")
+ })
+
+ middleware := AuthMiddleware(db)
+ wrappedHandler := middleware(handler)
+
+ tests := []struct {
+ name string
+ remoteAddr string
+ }{
+ {"private network IP", "192.168.1.100:5000"},
+ {"public IP", "8.8.8.8:80"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Send regular JSON without proper auth fields - should fail at auth token lookup
+ req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`)))
+ req.RemoteAddr = tt.remoteAddr
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ // Should fail because auth token is empty/invalid
+ assert.Equal(t, http.StatusUnauthorized, w.Code, "remote should fail auth")
+ assert.Contains(t, w.Body.String(), "Invalid auth token")
+ })
+ }
+
+ userDB.AssertExpectations(t)
+}
+
+func TestAuthMiddleware_EncryptedRequest(t *testing.T) {
+ t.Parallel()
+ // Create a mock device with known shared secret
+ testSecret := []byte("test-secret-key-32-bytes-long-ok")
+ testDevice := &database.Client{
+ ClientID: "test-device-id",
+ ClientName: "Test Device",
+ AuthTokenHash: "test-token-hash",
+ SharedSecret: testSecret,
+ CurrentSeq: 0,
+ SeqWindow: make([]byte, 8),
+ NonceCache: []string{},
+ CreatedAt: time.Now(),
+ LastSeen: time.Now(),
+ }
+
+ userDB := helpers.NewMockUserDBI()
+ userDB.On("GetClientByAuthToken", "test-auth-token").Return(testDevice, nil)
+ userDB.On("UpdateClientSequence", "test-device-id", uint64(1),
+ mock.AnythingOfType("[]uint8"), mock.AnythingOfType("[]string")).Return(nil)
+
+ db := &database.Database{UserDB: userDB}
+
+ // Create encrypted payload
+ payload := DecryptedPayload{
+ JSONRPC: "2.0",
+ Method: "test.method",
+ ID: 1,
+ Seq: 1,
+ Nonce: "test-nonce-123",
+ }
+
+ payloadJSON, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ encryptedData, iv, err := EncryptPayload(payloadJSON, testSecret)
+ require.NoError(t, err)
+
+ encRequest := EncryptedRequest{
+ Encrypted: encryptedData,
+ IV: iv,
+ AuthToken: "test-auth-token",
+ }
+
+ encRequestJSON, err := json.Marshal(encRequest)
+ require.NoError(t, err)
+
+ // Test handler that checks the decrypted request
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Verify the request body was decrypted properly
+ var jsonRPC map[string]any
+ err := json.NewDecoder(r.Body).Decode(&jsonRPC)
+ assert.NoError(t, err)
+
+ assert.Equal(t, "2.0", jsonRPC["jsonrpc"])
+ assert.Equal(t, "test.method", jsonRPC["method"])
+ assert.InDelta(t, float64(1), jsonRPC["id"], 0.001) // JSON unmarshals numbers as float64
+
+ // Verify device is in context
+ device := GetClientFromContext(r.Context())
+ assert.NotNil(t, device)
+ assert.Equal(t, "test-device-id", device.ClientID)
+
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("authenticated success"))
+ })
+
+ middleware := AuthMiddleware(db)
+ wrappedHandler := middleware(handler)
+
+ // Test successful authentication
+ req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(encRequestJSON))
+ req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, "authenticated success", w.Body.String())
+ userDB.AssertExpectations(t)
+}
+
+//nolint:paralleltest,tparallel // Security tests require deterministic mock validation
+func TestReplayProtector_ReplayProtection(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ newNonce string
+ description string
+ seqWindow []byte
+ nonceCache []string
+ currentSeq uint64
+ newSeq uint64
+ expectedResult bool
+ }{
+ {
+ name: "first message",
+ currentSeq: 0,
+ seqWindow: make([]byte, 8+128*8), // Ring buffer: 8 bytes + 128 blocks * 8 bytes
+ nonceCache: []string{},
+ newSeq: 1,
+ newNonce: "nonce1",
+ expectedResult: true,
+ description: "first message should always pass",
+ },
+ {
+ name: "sequence increment",
+ currentSeq: 5,
+ seqWindow: make([]byte, 8+128*8),
+ nonceCache: []string{"old-nonce"},
+ newSeq: 6,
+ newNonce: "nonce6",
+ expectedResult: true,
+ description: "incrementing sequence should pass",
+ },
+ {
+ name: "duplicate nonce",
+ currentSeq: 5,
+ seqWindow: make([]byte, 8+128*8),
+ nonceCache: []string{"duplicate-nonce"},
+ newSeq: 6,
+ newNonce: "duplicate-nonce",
+ expectedResult: false,
+ description: "duplicate nonce should be rejected",
+ },
+ {
+ name: "old sequence far out of window",
+ currentSeq: 50000,
+ seqWindow: make([]byte, 8+128*8),
+ nonceCache: []string{},
+ newSeq: 100, // More than 8000+ behind (outside WireGuard window)
+ newNonce: "nonce100",
+ expectedResult: false,
+ description: "sequence too far behind should be rejected",
+ },
+ {
+ name: "sequence within large window",
+ currentSeq: 1000,
+ seqWindow: make([]byte, 8+128*8),
+ nonceCache: []string{},
+ newSeq: 950, // Within large window
+ newNonce: "nonce950",
+ expectedResult: true,
+ description: "sequence within sliding window should pass",
+ },
+ {
+ name: "large sequence jump forward",
+ currentSeq: 5,
+ seqWindow: make([]byte, 8+128*8),
+ nonceCache: []string{},
+ newSeq: 1000, // Large jump forward
+ newNonce: "nonce1000",
+ expectedResult: true,
+ description: "large sequence jump should be accepted",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client := &database.Client{
+ ClientID: "test-device",
+ CurrentSeq: tt.currentSeq,
+ SeqWindow: tt.seqWindow,
+ NonceCache: tt.nonceCache,
+ }
+
+ replayProtector := NewReplayProtector(client)
+ result := replayProtector.ValidateSequenceAndNonce(tt.newSeq, tt.newNonce)
+ assert.Equal(t, tt.expectedResult, result, tt.description)
+ })
+ }
+}
+
+func TestEncryptDecryptPayload(t *testing.T) {
+ t.Parallel()
+ testKey := []byte("test-encryption-key-32bytes-ok!!")
+ originalData := []byte(`{"jsonrpc":"2.0","method":"test","id":1}`)
+
+ // Test encryption
+ encrypted, iv, err := EncryptPayload(originalData, testKey)
+ require.NoError(t, err)
+ assert.NotEmpty(t, encrypted)
+ assert.NotEmpty(t, iv)
+
+ // Test decryption
+ decrypted, err := DecryptPayload(encrypted, iv, testKey)
+ require.NoError(t, err)
+ assert.Equal(t, originalData, decrypted)
+}
+
+func TestEncryptDecryptPayload_WrongKey(t *testing.T) {
+ t.Parallel()
+ correctKey := []byte("correct-key-32-bytes-long-ok!!!!")
+ wrongKey := []byte("wrong-key-32-bytes-long-ok!!!!!!")
+ originalData := []byte(`{"test": "data"}`)
+
+ // Encrypt with correct key
+ encrypted, iv, err := EncryptPayload(originalData, correctKey)
+ require.NoError(t, err)
+
+ // Try to decrypt with wrong key - should fail
+ _, err = DecryptPayload(encrypted, iv, wrongKey)
+ require.Error(t, err, "decryption should fail with wrong key")
+ assert.Contains(t, err.Error(), "decryption failed")
+}
+
+//nolint:paralleltest,tparallel // Security tests require deterministic execution order
+func TestIsLocalhost(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ addr string
+ expected bool
+ }{
+ {"127.0.0.1:8080", true},
+ {"127.0.0.1", true},
+ {"localhost:3000", true},
+ {"localhost", true},
+ {"[::1]:8080", true},
+ {"::1", true},
+ {"192.168.1.1:8080", false},
+ {"8.8.8.8:53", false},
+ {"example.com:80", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.addr, func(t *testing.T) {
+ result := isLocalhost(tt.addr)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+//nolint:paralleltest,tparallel // Security tests require deterministic mock validation
+func TestAuthMiddleware_InvalidRequests(t *testing.T) {
+ t.Parallel()
+ userDB := helpers.NewMockUserDBI()
+ // Mock empty auth token lookup - expect it to be called once for the missing auth token test
+ userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Once()
+
+ db := &database.Database{UserDB: userDB}
+
+ handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ t.Error("handler should not be called for invalid requests")
+ })
+
+ middleware := AuthMiddleware(db)
+ wrappedHandler := middleware(handler)
+
+ tests := []struct {
+ name string
+ body string
+ description string
+ expectedCode int
+ }{
+ {
+ name: "invalid json",
+ body: `{"invalid json`,
+ expectedCode: http.StatusBadRequest,
+ description: "malformed JSON should be rejected",
+ },
+ {
+ name: "missing auth token",
+ body: `{"encrypted": "data", "iv": "iv"}`,
+ expectedCode: http.StatusUnauthorized,
+ description: "missing auth token should be rejected",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(tt.body)))
+ req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ assert.Equal(t, tt.expectedCode, w.Code, tt.description)
+ })
+ }
+
+ userDB.AssertExpectations(t)
+}
+
+func TestGetClientFromContext(t *testing.T) {
+ t.Parallel()
+ // Test with client in context
+ client := &database.Client{ClientID: "test-device"}
+ ctx := context.WithValue(context.Background(), clientKey("client"), client)
+
+ result := GetClientFromContext(ctx)
+ assert.Equal(t, client, result)
+
+ // Test with no client in context
+ emptyCtx := context.Background()
+ result = GetClientFromContext(emptyCtx)
+ assert.Nil(t, result)
+
+ // Test with wrong type in context
+ badCtx := context.WithValue(context.Background(), clientKey("client"), "not-a-client")
+ result = GetClientFromContext(badCtx)
+ assert.Nil(t, result)
+}
+
+// TestAuthMiddleware_ConcurrentRequests verifies that the race condition fix
+// prevents concurrent requests from bypassing replay protection
+func TestAuthMiddleware_ConcurrentRequests(t *testing.T) {
+ t.Parallel()
+
+ // This test verifies the mutex locking works correctly
+ // We'll just test that concurrent access to the mutex manager is safe
+
+ const numConcurrentRequests = 20
+ const deviceID = "test-device-concurrent"
+
+ done := make(chan struct{}, numConcurrentRequests)
+ var lockAcquired int32
+
+ for range numConcurrentRequests {
+ go func() {
+ defer func() { done <- struct{}{} }()
+
+ // Acquire device lock - this should be thread-safe
+ unlockDevice := LockClient(deviceID)
+
+ // Critical section - only one goroutine should be here at a time
+ current := atomic.AddInt32(&lockAcquired, 1)
+ if current != 1 {
+ t.Errorf("Race condition detected: %d goroutines in critical section", current)
+ }
+
+ // Simulate some work
+ time.Sleep(1 * time.Millisecond)
+
+ atomic.AddInt32(&lockAcquired, -1)
+ unlockDevice()
+ }()
+ }
+
+ // Wait for all requests to complete
+ for range numConcurrentRequests {
+ <-done
+ }
+
+ // Verify no race conditions occurred
+ assert.Equal(t, int32(0), atomic.LoadInt32(&lockAcquired), "All locks should be released")
+}
+
+// TestAuthMiddleware_ConcurrentAuthentication tests actual concurrent authentication
+// requests to ensure the race condition fix prevents replay attacks
+func TestAuthMiddleware_ConcurrentAuthentication(t *testing.T) {
+ t.Parallel()
+
+ // Setup
+ userDB := helpers.NewMockUserDBI()
+ db := &database.Database{UserDB: userDB}
+
+ // Create a test client
+ testClient := &database.Client{
+ ClientID: "test-client-concurrent",
+ ClientName: "Test Client",
+ AuthTokenHash: "test-hash",
+ SharedSecret: []byte("test-secret-32-bytes-long-key!!!"),
+ CurrentSeq: 5,
+ SeqWindow: make([]byte, 8),
+ NonceCache: []string{},
+ CreatedAt: time.Now(),
+ LastSeen: time.Now(),
+ }
+
+ // Mock database calls
+ userDB.On("GetClientByAuthToken", "test-token").Return(testClient, nil)
+ userDB.On("UpdateClientSequence",
+ testClient.ClientID,
+ mock.AnythingOfType("uint64"),
+ mock.AnythingOfType("[]uint8"),
+ mock.AnythingOfType("[]string")).Return(nil)
+
+ // Create test handler that tracks successful authentications
+ var successCount int32
+ handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&successCount, 1)
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("success"))
+ })
+
+ middleware := AuthMiddleware(db)
+ wrappedHandler := middleware(handler)
+
+ // Create encrypted request with same sequence number (should cause replay detection)
+ payload := map[string]any{
+ "jsonrpc": "2.0",
+ "method": "test.method",
+ "id": 1,
+ "nonce": "test-nonce-concurrent",
+ "seq": uint64(6), // Same sequence for all requests
+ }
+ payloadJSON, _ := json.Marshal(payload)
+
+ encrypted, iv, err := EncryptPayload(payloadJSON, testClient.SharedSecret)
+ require.NoError(t, err)
+
+ encRequest := map[string]string{
+ "encrypted": encrypted,
+ "iv": iv,
+ "authToken": "test-token",
+ }
+ requestBody, _ := json.Marshal(encRequest)
+
+ // Run concurrent requests
+ const numRequests = 10
+ done := make(chan int, numRequests)
+
+ for i := range numRequests {
+ go func(_ int) {
+ req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(requestBody))
+ req.RemoteAddr = "192.168.1.100:12345" // Remote address to trigger auth
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ done <- w.Code
+ }(i)
+ }
+
+ // Collect results
+ statusCodes := make([]int, 0, numRequests)
+ for range numRequests {
+ statusCodes = append(statusCodes, <-done)
+ }
+
+ // Only ONE request should succeed (200), others should fail with replay detection (400)
+ successStatusCount := 0
+ for _, code := range statusCodes {
+ if code == http.StatusOK {
+ successStatusCount++
+ } else {
+ assert.Equal(t, http.StatusBadRequest, code, "Failed requests should return 400 for replay detection")
+ }
+ }
+
+ assert.Equal(t, 1, successStatusCount, "Only one concurrent request should succeed")
+ assert.Equal(t, int32(1), atomic.LoadInt32(&successCount), "Handler should only be called once")
+
+ userDB.AssertExpectations(t)
+}
+
+// TestClientMutexManager_Cleanup verifies mutex cleanup works correctly
+func TestClientMutexManager_Cleanup(t *testing.T) {
+ t.Parallel()
+
+ dm := &ClientMutexManager{}
+
+ // Create some mutexes
+ mutex1 := dm.getClientMutex("device1")
+ mutex2 := dm.getClientMutex("device2")
+ mutex3 := dm.getClientMutex("device3")
+
+ require.NotNil(t, mutex1)
+ require.NotNil(t, mutex2)
+ require.NotNil(t, mutex3)
+
+ // Age some mutexes
+ mutex1.lastUsed = time.Now().Add(-31 * time.Minute) // Should be cleaned
+ mutex2.lastUsed = time.Now().Add(-20 * time.Minute) // Should remain
+ mutex3.lastUsed = time.Now() // Should remain
+
+ // Run cleanup
+ dm.cleanup()
+
+ // Check that only old mutex was removed
+ _, exists1 := dm.mutexes.Load("device1")
+ _, exists2 := dm.mutexes.Load("device2")
+ _, exists3 := dm.mutexes.Load("device3")
+
+ assert.False(t, exists1, "Old mutex should be cleaned up")
+ assert.True(t, exists2, "Recent mutex should remain")
+ assert.True(t, exists3, "Current mutex should remain")
+}
+
+// TestClientMutexManager_ConcurrentAccess verifies thread safety of mutex manager
+func TestClientMutexManager_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ dm := &ClientMutexManager{}
+ const numGoroutines = 50
+ const deviceID = "concurrent-test-device"
+
+ done := make(chan struct{}, numGoroutines)
+
+ // Launch multiple goroutines that get/create mutex for same device
+ for range numGoroutines {
+ go func() {
+ defer func() { done <- struct{}{} }()
+
+ mutex := dm.getClientMutex(deviceID)
+ assert.NotNil(t, mutex)
+ if mutex != nil {
+ assert.Equal(t, deviceID, mutex.clientID)
+ }
+ }()
+ }
+
+ // Wait for all goroutines to complete
+ for range numGoroutines {
+ <-done
+ }
+
+ // Verify only one mutex was created for the device
+ value, exists := dm.mutexes.Load(deviceID)
+ assert.True(t, exists, "Mutex should exist")
+
+ mutex, ok := value.(*clientMutex)
+ require.True(t, ok)
+ assert.Equal(t, deviceID, mutex.clientID)
+}
diff --git a/pkg/api/middleware/ratelimit_test.go b/pkg/api/middleware/ratelimit_test.go
new file mode 100644
index 000000000..891d7002b
--- /dev/null
+++ b/pkg/api/middleware/ratelimit_test.go
@@ -0,0 +1,206 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/time/rate"
+)
+
+func TestIPRateLimiter_BasicFunctionality(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ // Get limiter for IP
+ rl := limiter.GetLimiter("192.168.1.100")
+ assert.NotNil(t, rl)
+
+ // Should allow initial requests up to burst size
+ for i := range BurstSize {
+ allowed := rl.Allow()
+ assert.True(t, allowed, "should allow request %d within burst", i+1)
+ }
+
+ // Should block additional requests beyond burst
+ blocked := rl.Allow()
+ assert.False(t, blocked, "should block request beyond burst size")
+}
+
+func TestIPRateLimiter_DifferentIPs(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ // Get limiters for different IPs
+ rl1 := limiter.GetLimiter("192.168.1.100")
+ rl2 := limiter.GetLimiter("192.168.1.101")
+
+ // Should be different limiters
+ assert.NotSame(t, rl1, rl2)
+
+ // Exhaust first limiter
+ for range BurstSize {
+ rl1.Allow()
+ }
+
+ // First limiter should be blocked
+ assert.False(t, rl1.Allow())
+
+ // Second limiter should still allow requests
+ assert.True(t, rl2.Allow())
+}
+
+func TestIPRateLimiter_SameIPReuse(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ // Get limiter for same IP twice
+ rl1 := limiter.GetLimiter("192.168.1.100")
+ rl2 := limiter.GetLimiter("192.168.1.100")
+
+ // Should be same limiter instance
+ assert.Same(t, rl1, rl2)
+}
+
+func TestHTTPRateLimitMiddleware_Allow(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("success"))
+ })
+
+ middleware := HTTPRateLimitMiddleware(limiter)
+ wrappedHandler := middleware(handler)
+
+ // Should allow initial requests
+ for i := range BurstSize {
+ req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody)
+ req.RemoteAddr = "192.168.1.100:12345"
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code, "should allow request %d", i+1)
+ assert.Equal(t, "success", w.Body.String())
+ }
+}
+
+func TestHTTPRateLimitMiddleware_Block(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
+ t.Error("handler should not be called when rate limited")
+ })
+
+ middleware := HTTPRateLimitMiddleware(limiter)
+ wrappedHandler := middleware(handler)
+
+ // Exhaust rate limit
+ ip := "192.168.1.100:12345"
+ ipLimiter := limiter.GetLimiter("192.168.1.100")
+ for range BurstSize {
+ ipLimiter.Allow()
+ }
+
+ // Next request should be blocked
+ req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody)
+ req.RemoteAddr = ip
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusTooManyRequests, w.Code)
+ assert.Contains(t, w.Body.String(), "Too Many Requests")
+}
+
+func TestIPRateLimiter_Cleanup(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ // Add a limiter with old timestamp manually
+ limiter.limiters["old.ip"] = &rateLimiterEntry{
+ limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize),
+ lastSeen: time.Now().Add(-15 * time.Minute), // Old timestamp
+ }
+
+ // Add a limiter with recent timestamp manually
+ limiter.limiters["new.ip"] = &rateLimiterEntry{
+ limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize),
+ lastSeen: time.Now(), // Recent timestamp
+ }
+
+ // Verify both exist
+ assert.Len(t, limiter.limiters, 2)
+
+ // Run cleanup
+ limiter.Cleanup()
+
+ // Old entry should be removed, new one should remain
+ assert.Len(t, limiter.limiters, 1)
+ assert.Contains(t, limiter.limiters, "new.ip")
+ assert.NotContains(t, limiter.limiters, "old.ip")
+}
+
+func TestHTTPRateLimitMiddleware_IPExtraction(t *testing.T) {
+ t.Parallel()
+ limiter := NewIPRateLimiter()
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ middleware := HTTPRateLimitMiddleware(limiter)
+ wrappedHandler := middleware(handler)
+
+ tests := []struct {
+ name string
+ remoteAddr string
+ expectedIP string
+ }{
+ {"with port", "192.168.1.100:12345", "192.168.1.100"},
+ {"without port", "192.168.1.100", "192.168.1.100"},
+ {"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody)
+ req.RemoteAddr = tt.remoteAddr
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ // Should succeed (IP extraction worked)
+ assert.Equal(t, http.StatusOK, w.Code)
+
+ // Verify the correct IP was used for rate limiting
+ ipLimiter := limiter.GetLimiter(tt.expectedIP)
+ assert.NotNil(t, ipLimiter)
+ })
+ }
+}
diff --git a/pkg/api/middleware/replay_filter.go b/pkg/api/middleware/replay_filter.go
new file mode 100644
index 000000000..7bf923dec
--- /dev/null
+++ b/pkg/api/middleware/replay_filter.go
@@ -0,0 +1,93 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package middleware
+
+/*
+ATTRIBUTION NOTICE:
+
+This file contains code derived from WireGuard's anti-replay filter implementation.
+
+Original work: Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
+Original license: MIT License
+Source: https://github.com/WireGuard/wireguard-go/blob/master/replay/replay.go
+
+The derived portions remain under their original MIT license as required.
+The combined work is distributed under GPL-3.0-or-later as indicated in the file header.
+*/
+
+type block uint64
+
+const (
+ blockBitLog = 6
+ blockBits = 1 << blockBitLog
+ ringBlocks = 1 << 7
+ windowSize = (ringBlocks - 1) * blockBits
+ blockMask = ringBlocks - 1
+ bitMask = blockBits - 1
+)
+
+// Filter implements an anti-replay mechanism as specified in RFC 6479.
+// It rejects replayed messages by checking if message counter value is within
+// a sliding window of previously received messages.
+type Filter struct {
+ last uint64
+ ring [ringBlocks]block
+}
+
+// Reset resets the filter to an empty state.
+func (f *Filter) Reset() {
+ f.last = 0
+ for i := range f.ring {
+ f.ring[i] = 0
+ }
+}
+
+// ValidateCounter checks if a given counter should be accepted.
+// It automatically rejects counters >= the specified limit.
+func (f *Filter) ValidateCounter(counter, limit uint64) bool {
+ if counter >= limit {
+ return false
+ }
+
+ indexBlock := counter >> blockBitLog
+
+ if counter > f.last {
+ // New highest counter - shift window
+ current := f.last >> blockBitLog
+ diff := indexBlock - current
+ if diff > ringBlocks {
+ diff = ringBlocks
+ }
+ for i := current + 1; i <= current+diff; i++ {
+ f.ring[i&blockMask] = 0
+ }
+ f.last = counter
+ } else if f.last-counter > windowSize {
+ // Too old
+ return false
+ }
+
+ indexBlock &= blockMask
+ indexBit := counter & bitMask
+ old := f.ring[indexBlock]
+ updated := old | 1<.
+
+package middleware
+
+import (
+ "testing"
+)
+
+const RejectAfterMessages = 1<<64 - 1<<13 - 1
+
+func TestReplayFilter(t *testing.T) {
+ t.Parallel()
+ var filter Filter
+ const tLim = windowSize + 1
+ testNumber := 0
+
+ testFunc := func(n uint64, expected bool) {
+ testNumber++
+ if filter.ValidateCounter(n, RejectAfterMessages) != expected {
+ t.Fatal("Test", testNumber, "failed", n, expected)
+ }
+ }
+
+ filter.Reset()
+
+ testFunc(0, true)
+ testFunc(1, true)
+ testFunc(1, false)
+ testFunc(9, true)
+ testFunc(8, true)
+ testFunc(7, true)
+ testFunc(7, false)
+ testFunc(tLim, true)
+ testFunc(tLim-1, true)
+ testFunc(tLim-1, false)
+ testFunc(tLim-2, true)
+ testFunc(2, true)
+ testFunc(2, false)
+ testFunc(tLim+16, true)
+ testFunc(3, false)
+ testFunc(tLim+16, false)
+ testFunc(tLim*4, true)
+ testFunc(tLim*4-(tLim-1), true)
+ testFunc(10, false)
+ testFunc(tLim*4-tLim, false)
+ testFunc(tLim*4-(tLim+1), false)
+ testFunc(tLim*4-(tLim-2), true)
+ testFunc(tLim*4+1-tLim, false)
+ testFunc(0, false)
+ testFunc(RejectAfterMessages, false)
+ testFunc(RejectAfterMessages-1, true)
+ testFunc(RejectAfterMessages, false)
+ testFunc(RejectAfterMessages-1, false)
+ testFunc(RejectAfterMessages-2, true)
+ testFunc(RejectAfterMessages+1, false)
+ testFunc(RejectAfterMessages+2, false)
+ testFunc(RejectAfterMessages-2, false)
+ testFunc(RejectAfterMessages-3, true)
+ testFunc(0, false)
+
+ t.Log("Bulk test 1")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(1); i <= windowSize; i++ {
+ testFunc(i, true)
+ }
+ testFunc(0, true)
+ testFunc(0, false)
+
+ t.Log("Bulk test 2")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(2); i <= windowSize+1; i++ {
+ testFunc(i, true)
+ }
+ testFunc(1, true)
+ testFunc(0, false)
+
+ t.Log("Bulk test 3")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(windowSize + 1); i > 0; i-- {
+ testFunc(i, true)
+ }
+
+ t.Log("Bulk test 4")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(windowSize + 2); i > 1; i-- {
+ testFunc(i, true)
+ }
+ testFunc(0, false)
+
+ t.Log("Bulk test 5")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(windowSize); i > 0; i-- {
+ testFunc(i, true)
+ }
+ testFunc(windowSize+1, true)
+ testFunc(0, false)
+
+ t.Log("Bulk test 6")
+ filter.Reset()
+ testNumber = 0
+ for i := uint64(windowSize); i > 0; i-- {
+ testFunc(i, true)
+ }
+ testFunc(0, true)
+ testFunc(windowSize+1, true)
+}
diff --git a/pkg/api/middleware/replay_protection.go b/pkg/api/middleware/replay_protection.go
new file mode 100644
index 000000000..d28146472
--- /dev/null
+++ b/pkg/api/middleware/replay_protection.go
@@ -0,0 +1,113 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package middleware
+
+import (
+ "encoding/binary"
+ "slices"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database"
+)
+
+const (
+ MaxSequenceNumber = 1<<64 - 1<<13 - 1 // Same as WireGuard's RejectAfterMessages
+ NonceCacheSize = 100 // Maximum number of cached nonces
+)
+
+// ReplayProtector combines WireGuard's sequence number validation with nonce replay protection
+type ReplayProtector struct {
+ nonceCache []string
+ filter Filter
+ lastSeq uint64
+}
+
+// NewReplayProtector creates a new replay protector from client state
+func NewReplayProtector(client *database.Client) *ReplayProtector {
+ rp := &ReplayProtector{
+ nonceCache: make([]string, len(client.NonceCache)),
+ lastSeq: client.CurrentSeq,
+ }
+ copy(rp.nonceCache, client.NonceCache)
+
+ // Initialize filter from stored ring buffer state
+ if len(client.SeqWindow) >= 8+len(rp.filter.ring)*8 {
+ storedLastSeq := binary.LittleEndian.Uint64(client.SeqWindow[0:8])
+ if storedLastSeq > 0 {
+ // Load stored state: first 8 bytes = last sequence, then 8 bytes per ring block
+ rp.filter.last = storedLastSeq
+ for i := range rp.filter.ring {
+ offset := 8 + i*8
+ rp.filter.ring[i] = block(binary.LittleEndian.Uint64(client.SeqWindow[offset : offset+8]))
+ }
+ } else {
+ // Buffer exists but uninitialized - treat as new client
+ rp.filter.Reset()
+ if client.CurrentSeq > 0 {
+ rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber)
+ }
+ }
+ } else {
+ // Buffer too small - initialize fresh
+ rp.filter.Reset()
+ if client.CurrentSeq > 0 {
+ rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber)
+ }
+ }
+
+ return rp
+}
+
+// ValidateSequenceAndNonce validates both sequence number and nonce for replay protection
+func (rp *ReplayProtector) ValidateSequenceAndNonce(seq uint64, nonce string) bool {
+ // Check nonce first (replay protection)
+ if slices.Contains(rp.nonceCache, nonce) {
+ return false
+ }
+
+ // Then validate sequence number using WireGuard's algorithm
+ return rp.filter.ValidateCounter(seq, MaxSequenceNumber)
+}
+
+// UpdateSequenceAndNonce updates the replay protector state after successful validation
+func (rp *ReplayProtector) UpdateSequenceAndNonce(seq uint64, nonce string) {
+ // Update nonce cache (keep last NonceCacheSize nonces)
+ rp.nonceCache = append(rp.nonceCache, nonce)
+ if len(rp.nonceCache) > NonceCacheSize {
+ rp.nonceCache = rp.nonceCache[1:] // Remove oldest
+ }
+
+ // Sequence is already updated by ValidateCounter if it was accepted
+ rp.lastSeq = max(rp.lastSeq, seq)
+}
+
+// GetStateForDatabase returns the state that should be stored in the database
+func (rp *ReplayProtector) GetStateForDatabase() (currentSeq uint64, seqWindow []byte, nonceCache []string) {
+ // Serialize ring buffer state
+ // Format: [8 bytes last seq][8 bytes per ring block]
+ seqWindow = make([]byte, 8+len(rp.filter.ring)*8)
+ binary.LittleEndian.PutUint64(seqWindow[0:8], rp.filter.last)
+
+ for i, block := range rp.filter.ring[:] {
+ offset := 8 + i*8
+ binary.LittleEndian.PutUint64(seqWindow[offset:offset+8], uint64(block))
+ }
+
+ return rp.filter.last, seqWindow, rp.nonceCache
+}
diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go
new file mode 100644
index 000000000..335c604aa
--- /dev/null
+++ b/pkg/api/pairing.go
@@ -0,0 +1,355 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package api
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "regexp"
+ "runtime"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database"
+ "github.com/google/uuid"
+ "github.com/rs/zerolog/log"
+ "golang.org/x/crypto/argon2"
+ "golang.org/x/crypto/hkdf"
+)
+
+const (
+ PairingTokenExpiry = 5 * time.Minute
+ PairingAttemptLimit = 10
+)
+
+type PairingSession struct {
+ CreatedAt time.Time
+ Token string
+ Challenge []byte
+ Salt []byte
+ Attempts int
+}
+
+type PairingManager struct {
+ sessions map[string]*PairingSession
+ mu sync.RWMutex
+}
+
+type PairingInitiateRequest struct {
+ ClientName string `json:"clientName"`
+}
+
+type PairingInitiateResponse struct {
+ PairingToken string `json:"pairingToken"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+type PairingCompleteRequest struct {
+ PairingToken string `json:"pairingToken"`
+ Verifier string `json:"verifier"`
+ ClientName string `json:"clientName"`
+}
+
+type PairingCompleteResponse struct {
+ ClientID string `json:"clientId"`
+ AuthToken string `json:"authToken"`
+ SharedSecret string `json:"sharedSecret"` // Base64 encoded
+}
+
+var pairingManager = &PairingManager{
+ sessions: make(map[string]*PairingSession),
+}
+
+var clientNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
+
+func validateClientName(name string) error {
+ name = strings.TrimSpace(name)
+
+ if name == "" {
+ return errors.New("client name cannot be empty")
+ }
+
+ if len(name) > 100 {
+ return errors.New("client name too long (max 100 characters)")
+ }
+
+ if !clientNameRegex.MatchString(name) {
+ return errors.New("client name contains invalid characters (only letters, numbers, underscore, and dash)")
+ }
+
+ return nil
+}
+
+// StartPairingCleanup starts the pairing session cleanup routine
+func StartPairingCleanup(ctx context.Context) {
+ go pairingManager.cleanup(ctx)
+}
+
+func (pm *PairingManager) createSession() (*PairingSession, error) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
+
+ // Generate random token and challenge
+ token := uuid.New().String()
+ challenge := make([]byte, 32)
+ if _, err := rand.Read(challenge); err != nil {
+ return nil, fmt.Errorf("failed to generate challenge: %w", err)
+ }
+
+ // Generate random salt (32 bytes for SHA-256)
+ salt := make([]byte, 32)
+ if _, err := rand.Read(salt); err != nil {
+ return nil, fmt.Errorf("failed to generate salt: %w", err)
+ }
+
+ session := &PairingSession{
+ Token: token,
+ Challenge: challenge,
+ Salt: salt,
+ CreatedAt: time.Now(),
+ Attempts: 0,
+ }
+
+ pm.sessions[token] = session
+
+ log.Debug().Str("token", token[:8]+"...").Msg("created pairing session")
+ return session, nil
+}
+
+func (pm *PairingManager) consumeSession(token string) (*PairingSession, bool) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
+
+ session, exists := pm.sessions[token]
+ if !exists {
+ return nil, false
+ }
+
+ // Check if expired
+ if time.Since(session.CreatedAt) > PairingTokenExpiry {
+ delete(pm.sessions, token)
+ return nil, false
+ }
+
+ // Increment attempts
+ session.Attempts++
+
+ // Check attempt limit
+ if session.Attempts >= PairingAttemptLimit {
+ delete(pm.sessions, token)
+ return nil, false
+ }
+
+ return session, true
+}
+
+func (pm *PairingManager) cleanup(ctx context.Context) {
+ ticker := time.NewTicker(1 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ log.Debug().Msg("pairing manager cleanup routine stopped")
+ return
+ case <-ticker.C:
+ pm.mu.Lock()
+ now := time.Now()
+ for token, session := range pm.sessions {
+ if now.Sub(session.CreatedAt) > PairingTokenExpiry {
+ // Zero sensitive data before deletion
+ for i := range session.Challenge {
+ session.Challenge[i] = 0
+ }
+ for i := range session.Salt {
+ session.Salt[i] = 0
+ }
+ delete(pm.sessions, token)
+ }
+ }
+ pm.mu.Unlock()
+ }
+ }
+}
+
+func handlePairingInitiate(_ *database.Database) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ session, err := pairingManager.createSession()
+ if err != nil {
+ log.Error().Err(err).Msg("failed to create pairing session")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ response := PairingInitiateResponse{
+ PairingToken: session.Token,
+ ExpiresIn: int(PairingTokenExpiry.Seconds()),
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ log.Error().Err(err).Msg("failed to encode pairing response")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ log.Info().Str("token", session.Token[:8]+"...").Msg("pairing session initiated")
+ }
+}
+
+func handlePairingComplete(db *database.Database) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req PairingCompleteRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, "Invalid request format", http.StatusBadRequest)
+ return
+ }
+
+ if req.PairingToken == "" || req.Verifier == "" || req.ClientName == "" {
+ http.Error(w, "Missing required fields", http.StatusBadRequest)
+ return
+ }
+
+ // Validate client name
+ if err := validateClientName(req.ClientName); err != nil {
+ log.Warn().
+ Str("client_name", req.ClientName).
+ Str("remote_addr", r.RemoteAddr).
+ Err(err).
+ Msg("SECURITY: invalid client name provided")
+ http.Error(w, fmt.Sprintf("Invalid client name: %s", err.Error()), http.StatusBadRequest)
+ return
+ }
+
+ // Get and consume pairing session
+ session, exists := pairingManager.consumeSession(req.PairingToken)
+ if !exists {
+ http.Error(w, "Invalid or expired pairing token", http.StatusBadRequest)
+ return
+ }
+
+ // First, use Argon2id for strong key derivation from verifier
+ // Using security parameters recommended by OWASP (2023)
+ verifierKey := argon2.IDKey(
+ []byte(req.Verifier), // password
+ session.Salt, // salt (32 bytes)
+ 3, // time parameter (iterations)
+ 64*1024, // memory parameter (64 MB)
+ 4, // parallelism parameter
+ 32, // key length (32 bytes)
+ )
+ // Ensure verifierKey is zeroed after use
+ defer func() {
+ for i := range verifierKey {
+ verifierKey[i] = 0
+ }
+ runtime.KeepAlive(verifierKey) // Prevent compiler optimization
+ }()
+
+ // Then combine with challenge using HKDF for domain separation
+ combinedSecret := make([]byte, len(session.Challenge)+len(verifierKey))
+ copy(combinedSecret, session.Challenge)
+ copy(combinedSecret[len(session.Challenge):], verifierKey)
+ // Ensure combinedSecret is zeroed after use
+ defer func() {
+ for i := range combinedSecret {
+ combinedSecret[i] = 0
+ }
+ runtime.KeepAlive(combinedSecret) // Prevent compiler optimization
+ }()
+
+ sharedSecret := make([]byte, 32) // 256 bits for AES-256
+
+ // Construct context-specific info string for domain separation
+ info := []byte("zaparoo-pairing-v1|" + req.PairingToken + "|" + req.ClientName)
+
+ kdf := hkdf.New(sha256.New, combinedSecret, session.Salt, info)
+ if _, err := kdf.Read(sharedSecret); err != nil {
+ log.Error().Err(err).Msg("failed to derive shared secret")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ // Generate auth token
+ authToken := uuid.New().String()
+
+ // Create client in database
+ client, err := db.UserDB.CreateClient(req.ClientName, authToken, sharedSecret)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to create client")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ // Zero sensitive session data before removing
+ for i := range session.Challenge {
+ session.Challenge[i] = 0
+ }
+ for i := range session.Salt {
+ session.Salt[i] = 0
+ }
+
+ // Remove session from manager
+ pairingManager.mu.Lock()
+ delete(pairingManager.sessions, req.PairingToken)
+ pairingManager.mu.Unlock()
+
+ response := PairingCompleteResponse{
+ ClientID: client.ClientID,
+ AuthToken: authToken,
+ SharedSecret: hex.EncodeToString(sharedSecret),
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ log.Error().Err(err).Msg("failed to encode pairing response")
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ log.Info().
+ Str("client_id", client.ClientID).
+ Str("client_name", client.ClientName).
+ Msg("client paired successfully")
+ }
+}
+
+// QRCodeData represents the data embedded in the pairing QR code
+type QRCodeData struct {
+ Address string `json:"address"`
+ Token string `json:"token"`
+}
diff --git a/pkg/api/server.go b/pkg/api/server.go
index 4e0dc2f4d..86007649d 100644
--- a/pkg/api/server.go
+++ b/pkg/api/server.go
@@ -22,6 +22,7 @@ package api
import (
"bytes"
"context"
+ "crypto/rand"
"encoding/json"
"errors"
"fmt"
@@ -51,8 +52,10 @@ import (
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/google/uuid"
+ "github.com/gorilla/csrf"
"github.com/olahol/melody"
"github.com/rs/zerolog/log"
+ "github.com/unrolled/secure"
)
var allowedOrigins = []string{
@@ -427,7 +430,7 @@ func buildDynamicAllowedOrigins(baseOrigins, localIPs []string, port int, custom
}
// broadcastNotifications consumes and broadcasts all incoming API
-// notifications to all connected clients.
+// notifications to all connected clients with appropriate encryption.
func broadcastNotifications(
st *state.State,
session *melody.Melody,
@@ -451,11 +454,61 @@ func broadcastNotifications(
continue
}
- // TODO: this will not work with encryption
- err = session.Broadcast(data)
- if err != nil {
- log.Error().Err(err).Msg("broadcasting notification")
- }
+ // Broadcast to localhost sessions (unencrypted)
+ _ = session.BroadcastFilter(data, func(s *melody.Session) bool {
+ rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2)
+ clientIP := net.ParseIP(rawIP[0])
+ return clientIP.IsLoopback()
+ })
+
+ // Broadcast to authenticated remote sessions (encrypted)
+ _ = session.BroadcastFilter(nil, func(s *melody.Session) bool {
+ // Check if this is a remote connection
+ rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2)
+ clientIP := net.ParseIP(rawIP[0])
+ if clientIP.IsLoopback() {
+ return false // Skip localhost
+ }
+
+ // Check if session is authenticated
+ client, authenticated := s.Get("client")
+ if !authenticated {
+ return false // Skip unauthenticated
+ }
+
+ // Encrypt notification for this session
+ clientObj, ok := client.(*database.Client)
+ if !ok {
+ log.Error().Msg("invalid client type in session")
+ return false
+ }
+ encrypted, iv, err := apimiddleware.EncryptPayload(data, clientObj.SharedSecret)
+ if err != nil {
+ log.Error().Err(err).Str("client_id", clientObj.ClientID).Msg("failed to encrypt notification")
+ return false
+ }
+
+ encResponse := apimiddleware.EncryptedRequest{
+ Encrypted: encrypted,
+ IV: iv,
+ AuthToken: "", // Not needed for notifications
+ }
+
+ encData, err := json.Marshal(encResponse)
+ if err != nil {
+ log.Error().Err(err).Str("client_id", clientObj.ClientID).
+ Msg("failed to marshal encrypted notification")
+ return false
+ }
+
+ // Send encrypted data to this session
+ if err := s.Write(encData); err != nil {
+ log.Error().Err(err).Str("client_id", clientObj.ClientID).
+ Msg("failed to send encrypted notification")
+ }
+
+ return false // Don't include in broadcast since we already sent manually
+ })
}
}
}
@@ -517,6 +570,77 @@ func processRequestObject(
// handleWSMessage parses all incoming WS requests, identifies what type of
// JSON-RPC object they may be and forwards them to the appropriate function
// to handle that type of message.
+// handleWSPing handles ping messages for heartbeat operation
+func handleWSPing(session *melody.Session, msg []byte) bool {
+ if bytes.Equal(msg, []byte("ping")) {
+ err := session.Write([]byte("pong"))
+ if err != nil {
+ log.Error().Err(err).Msg("sending pong")
+ }
+ return true
+ }
+ return false
+}
+
+// handleWSAuthentication handles authentication for remote WebSocket connections
+func handleWSRemoteAuth(session *melody.Session, msg []byte, db *database.Database) ([]byte, error) {
+ client, authenticated := session.Get("client")
+ if !authenticated {
+ // First message must be authentication
+ err := handleWSAuthentication(session, msg, db)
+ if err != nil {
+ log.Warn().Err(err).
+ Str("remote_addr", session.Request.RemoteAddr).
+ Str("user_agent", session.Request.Header.Get("User-Agent")).
+ Msg("SECURITY: WebSocket authentication failed")
+ _ = session.Close()
+ }
+ return nil, err
+ }
+
+ // Check session timeout
+ if isSessionExpired(session) {
+ log.Warn().
+ Str("remote_addr", session.Request.RemoteAddr).
+ Msg("SECURITY: WebSocket session expired - closing connection")
+ _ = session.Close()
+ return nil, errors.New("session expired")
+ }
+
+ // Update last activity time
+ session.Set("auth_time", time.Now())
+
+ // Decrypt message for authenticated remote connection
+ clientObj, ok := client.(*database.Client)
+ if !ok {
+ log.Error().Msg("invalid client type in session")
+ return nil, errors.New("invalid client type")
+ }
+
+ decryptedMsg, err := handleWSDecryption(session, msg, clientObj, db)
+ if err != nil {
+ log.Error().Err(err).Msg("WebSocket decryption failed")
+ return nil, err
+ }
+
+ return decryptedMsg, nil
+}
+
+// sendWSResponseByType sends response based on connection type (local/remote)
+func sendWSResponseByType(session *melody.Session, isLocal bool, id uuid.UUID, resp any) error {
+ if !isLocal {
+ if client, authenticated := session.Get("client"); authenticated {
+ clientObj, ok := client.(*database.Client)
+ if !ok {
+ log.Error().Msg("invalid client type in session")
+ return errors.New("invalid client type")
+ }
+ return sendWSResponseEncrypted(session, id, resp, clientObj)
+ }
+ }
+ return sendWSResponse(session, id, resp)
+}
+
func handleWSMessage(
methodMap *MethodMap,
platform platforms.Platform,
@@ -536,24 +660,40 @@ func handleWSMessage(
}
}()
- // ping command for heartbeat operation
- if bytes.Equal(msg, []byte("ping")) {
- err := session.Write([]byte("pong"))
- if err != nil {
- log.Error().Err(err).Msg("sending pong")
- }
+ // Handle ping/pong heartbeat
+ if handleWSPing(session, msg) {
return
}
+ // Determine if connection is local
rawIP := strings.SplitN(session.Request.RemoteAddr, ":", 2)
clientIP := net.ParseIP(rawIP[0])
+ isLocal := clientIP.IsLoopback()
+
+ // Handle authentication and decryption for remote connections
+ if !isLocal {
+ decryptedMsg, err := handleWSRemoteAuth(session, msg, db)
+ if err != nil {
+ err := sendWSError(session, uuid.Nil, JSONRPCErrorInvalidRequest)
+ if err != nil {
+ log.Error().Err(err).Msg("error sending auth error response")
+ }
+ return
+ }
+ if decryptedMsg == nil {
+ return // Authentication in progress
+ }
+ msg = decryptedMsg
+ }
+
+ // Process the request
env := requests.RequestEnv{
Platform: platform,
Config: cfg,
State: st,
Database: db,
TokenQueue: inTokenQueue,
- IsLocal: clientIP.IsLoopback(),
+ IsLocal: isLocal,
}
id, resp, rpcError := processRequestObject(methodMap, env, msg)
@@ -563,7 +703,7 @@ func handleWSMessage(
log.Error().Err(err).Msg("error sending error response")
}
} else {
- err := sendWSResponse(session, id, resp)
+ err := sendWSResponseByType(session, isLocal, id, resp)
if err != nil {
log.Error().Err(err).Msg("error sending response")
}
@@ -571,6 +711,177 @@ func handleWSMessage(
}
}
+type WSAuthMessage struct {
+ AuthToken string `json:"authToken"`
+}
+
+type CSRFTokenResponse struct {
+ CSRFToken string `json:"csrfToken"`
+}
+
+const sessionTimeout = 30 * time.Minute
+
+func isSessionExpired(session *melody.Session) bool {
+ authTime, exists := session.Get("auth_time")
+ if !exists {
+ return true // No auth time means expired
+ }
+
+ timestamp, ok := authTime.(time.Time)
+ if !ok {
+ return true // Invalid timestamp means expired
+ }
+
+ return time.Since(timestamp) > sessionTimeout
+}
+
+func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Database) error {
+ var authMsg WSAuthMessage
+ if err := json.Unmarshal(msg, &authMsg); err != nil {
+ return fmt.Errorf("invalid auth message format: %w", err)
+ }
+
+ if authMsg.AuthToken == "" {
+ return errors.New("missing auth token")
+ }
+
+ // Validate auth token and get client
+ client, err := db.UserDB.GetClientByAuthToken(authMsg.AuthToken)
+ if err != nil {
+ return fmt.Errorf("invalid auth token: %w", err)
+ }
+
+ // Store client and auth timestamp in session
+ session.Set("client", client)
+ session.Set("auth_time", time.Now())
+
+ // Send authentication success response
+ authResponse := map[string]any{
+ "authenticated": true,
+ "client_id": client.ClientID,
+ }
+
+ responseData, _ := json.Marshal(authResponse)
+ err = session.Write(responseData)
+ if err != nil {
+ return fmt.Errorf("failed to send auth response: %w", err)
+ }
+
+ log.Info().
+ Str("client_id", client.ClientID).
+ Str("remote_addr", session.Request.RemoteAddr).
+ Msg("SECURITY: WebSocket authenticated successfully")
+ return nil
+}
+
+func handleWSDecryption(_ *melody.Session, msg []byte, client *database.Client, db *database.Database) ([]byte, error) {
+ // Parse encrypted message
+ var encMsg apimiddleware.EncryptedRequest
+ if err := json.Unmarshal(msg, &encMsg); err != nil {
+ return nil, fmt.Errorf("invalid encrypted message format: %w", err)
+ }
+
+ // Decrypt payload (we can reuse the middleware function by creating a temporary import)
+ decryptedPayload, err := apimiddleware.DecryptPayload(encMsg.Encrypted, encMsg.IV, client.SharedSecret)
+ if err != nil {
+ return nil, fmt.Errorf("decryption failed: %w", err)
+ }
+
+ // Parse and validate sequence/nonce
+ var payload apimiddleware.DecryptedPayload
+ if unmarshalErr := json.Unmarshal(decryptedPayload, &payload); unmarshalErr != nil {
+ return nil, fmt.Errorf("invalid decrypted payload: %w", unmarshalErr)
+ }
+
+ // Acquire client lock to prevent race conditions
+ // between validation and database update
+ unlockClient := apimiddleware.LockClient(client.ClientID)
+ defer unlockClient()
+
+ // Re-fetch client state under lock to get latest sequence/nonce state
+ freshClient, err := db.UserDB.GetClientByID(client.ClientID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to re-fetch client under lock: %w", err)
+ }
+
+ // Create replay protector from fresh client state
+ replayProtector := apimiddleware.NewReplayProtector(freshClient)
+
+ // Validate sequence and nonce with fresh client state
+ if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) {
+ return nil, errors.New("invalid sequence or replay detected")
+ }
+
+ // Update replay protector state
+ replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce)
+
+ // Get updated state for database storage
+ currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase()
+ freshClient.CurrentSeq = currentSeq
+ freshClient.SeqWindow = seqWindow
+ freshClient.NonceCache = nonceCache
+
+ if updateErr := db.UserDB.UpdateClientSequence(
+ freshClient.ClientID, freshClient.CurrentSeq, freshClient.SeqWindow, freshClient.NonceCache,
+ ); updateErr != nil {
+ return nil, fmt.Errorf("failed to update client state: %w", updateErr)
+ }
+
+ // Return JSON-RPC payload without sequence/nonce
+ originalPayload := map[string]any{
+ "jsonrpc": payload.JSONRPC,
+ "method": payload.Method,
+ "id": payload.ID,
+ }
+ if payload.Params != nil {
+ originalPayload["params"] = payload.Params
+ }
+
+ result, err := json.Marshal(originalPayload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal payload: %w", err)
+ }
+ return result, nil
+}
+
+func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, client *database.Client) error {
+ // Create response object
+ resp := models.ResponseObject{
+ JSONRPC: "2.0",
+ ID: id,
+ Result: result,
+ }
+
+ // Marshal to JSON
+ data, err := json.Marshal(resp)
+ if err != nil {
+ return fmt.Errorf("error marshalling response: %w", err)
+ }
+
+ // Encrypt the response
+ encrypted, iv, err := apimiddleware.EncryptPayload(data, client.SharedSecret)
+ if err != nil {
+ return fmt.Errorf("failed to encrypt response: %w", err)
+ }
+
+ // Send encrypted response
+ encResponse := apimiddleware.EncryptedRequest{
+ Encrypted: encrypted,
+ IV: iv,
+ AuthToken: "", // Not needed for responses
+ }
+
+ encData, err := json.Marshal(encResponse)
+ if err != nil {
+ return fmt.Errorf("failed to marshal encrypted response: %w", err)
+ }
+
+ if err := session.Write(encData); err != nil {
+ return fmt.Errorf("failed to send encrypted response: %w", err)
+ }
+ return nil
+}
+
func handlePostRequest(
methodMap *MethodMap,
platform platforms.Platform,
@@ -632,7 +943,7 @@ func handlePostRequest(
}
w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusInternalServerError)
+ w.WriteHeader(http.StatusOK)
_, err = w.Write(respBody)
if err != nil {
log.Error().Err(err).Msg("failed to write error response")
@@ -675,14 +986,39 @@ func Start(
rateLimiter.StartCleanup(st.GetContext())
r.Use(apimiddleware.HTTPRateLimitMiddleware(rateLimiter))
+
+ // Start global mutex cleanup routine with proper context
+ apimiddleware.StartGlobalMutexCleanup(st.GetContext())
+
+ // Start pairing session cleanup routine with proper context
+ StartPairingCleanup(st.GetContext())
+
+ // Generate random CSRF key once at startup for security
+ csrfKey := make([]byte, 32)
+ if _, err := rand.Read(csrfKey); err != nil {
+ log.Fatal().Err(err).Msg("failed to generate CSRF key")
+ }
+
+ // Security headers middleware - protect against common attacks
+ secureMiddleware := secure.New(secure.Options{
+ FrameDeny: true, // X-Frame-Options: DENY
+ ContentTypeNosniff: true, // X-Content-Type-Options: nosniff
+ BrowserXssFilter: true, // X-XSS-Protection: 1; mode=block
+ ContentSecurityPolicy: "default-src 'self'; connect-src 'self' ws: wss:; " +
+ "img-src 'self' data:; style-src 'self' 'unsafe-inline'",
+ ReferrerPolicy: "strict-origin-when-cross-origin",
+ })
+ r.Use(secureMiddleware.Handler)
+
r.Use(middleware.Recoverer)
r.Use(middleware.NoCache)
r.Use(middleware.Timeout(config.APIRequestTimeout))
r.Use(cors.Handler(cors.Options{
- AllowedOrigins: dynamicAllowedOrigins,
- AllowedMethods: []string{"GET", "POST", "OPTIONS"},
- AllowedHeaders: []string{"Accept", "Content-Type"},
- ExposedHeaders: []string{},
+ AllowedOrigins: dynamicAllowedOrigins,
+ AllowedMethods: []string{"GET", "POST", "OPTIONS"},
+ AllowedHeaders: []string{"Accept", "Content-Type", "X-CSRF-Token"},
+ ExposedHeaders: []string{"X-CSRF-Token"},
+ AllowCredentials: true, // Required for CSRF protection
}))
if strings.HasSuffix(config.AppVersion, "-dev") {
@@ -700,29 +1036,72 @@ func Start(
}
go broadcastNotifications(st, session, notifications)
- r.Get("/api", func(w http.ResponseWriter, r *http.Request) {
- err := session.HandleRequest(w, r)
- if err != nil {
- log.Error().Err(err).Msg("handling websocket request: latest")
- }
+ // Pairing endpoints with CSRF protection and enhanced rate limiting
+ r.Route("/api/pair", func(r chi.Router) {
+ // CSRF protection for pairing endpoints
+ r.Use(csrf.Protect(
+ csrfKey,
+ csrf.Secure(false), // Set to true when using HTTPS
+ csrf.HttpOnly(true),
+ csrf.SameSite(csrf.SameSiteStrictMode),
+ ))
+
+ // Enhanced rate limiting for pairing (more restrictive)
+ pairingLimiter := apimiddleware.NewIPRateLimiter()
+ pairingLimiter.StartCleanup(st.GetContext())
+ r.Use(apimiddleware.HTTPRateLimitMiddleware(pairingLimiter))
+
+ r.Post("/initiate", handlePairingInitiate(db))
+ r.Post("/complete", handlePairingComplete(db))
+
+ // CSRF token endpoint for clients
+ r.Get("/csrf-token", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ token := csrf.Token(r)
+
+ response := CSRFTokenResponse{
+ CSRFToken: token,
+ }
+
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ log.Error().Err(err).Msg("Failed to write CSRF token response")
+ }
+ })
})
- r.Post("/api", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
- r.Get("/api/v0", func(w http.ResponseWriter, r *http.Request) {
- err := session.HandleRequest(w, r)
- if err != nil {
- log.Error().Err(err).Msg("handling websocket request: v0")
- }
+ // Protected API routes with authentication middleware
+ r.Route("/api", func(r chi.Router) {
+ r.Use(apimiddleware.AuthMiddleware(db))
+ r.Get("/", func(w http.ResponseWriter, r *http.Request) {
+ err := session.HandleRequest(w, r)
+ if err != nil {
+ log.Error().Err(err).Msg("handling websocket request: latest")
+ }
+ })
+ r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
})
- r.Post("/api/v0", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
- r.Get("/api/v0.1", func(w http.ResponseWriter, r *http.Request) {
- err := session.HandleRequest(w, r)
- if err != nil {
- log.Error().Err(err).Msg("handling websocket request: v0.1")
- }
+ r.Route("/api/v0", func(r chi.Router) {
+ r.Use(apimiddleware.AuthMiddleware(db))
+ r.Get("/", func(w http.ResponseWriter, r *http.Request) {
+ err := session.HandleRequest(w, r)
+ if err != nil {
+ log.Error().Err(err).Msg("handling websocket request: v0")
+ }
+ })
+ r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
+ })
+
+ r.Route("/api/v0.1", func(r chi.Router) {
+ r.Use(apimiddleware.AuthMiddleware(db))
+ r.Get("/", func(w http.ResponseWriter, r *http.Request) {
+ err := session.HandleRequest(w, r)
+ if err != nil {
+ log.Error().Err(err).Msg("handling websocket request: v0.1")
+ }
+ })
+ r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
})
- r.Post("/api/v0.1", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db))
session.HandleMessage(apimiddleware.WebSocketRateLimitHandler(
rateLimiter,
diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go
index 7fcdc2b13..0ee72bb80 100644
--- a/pkg/cli/cli.go
+++ b/pkg/cli/cli.go
@@ -40,16 +40,19 @@ import (
)
type Flags struct {
- Write *string
- Read *bool
- Run *string
- Launch *string
- API *string
- Version *bool
- Config *bool
- ShowLoader *string
- ShowPicker *string
- Reload *bool
+ Write *string
+ Read *bool
+ Run *string
+ Launch *string
+ API *string
+ Version *bool
+ Config *bool
+ ShowLoader *string
+ ShowPairingCode *bool
+ ListClients *bool
+ RevokeClient *string
+ ShowPicker *string
+ Reload *bool
}
// SetupFlags defines all common CLI flags between platforms.
@@ -95,6 +98,21 @@ func SetupFlags() *Flags {
false,
"reload config and mappings from disk",
),
+ ShowPairingCode: flag.Bool(
+ "show-pairing-code",
+ false,
+ "display QR code for client pairing",
+ ),
+ ListClients: flag.Bool(
+ "list-clients",
+ false,
+ "list all paired clients",
+ ),
+ RevokeClient: flag.String(
+ "revoke-client",
+ "",
+ "revoke access for client by ID",
+ ),
}
}
@@ -139,8 +157,17 @@ func runFlag(cfg *config.Instance, value string) {
// Post actions all remaining common flags that require the environment to be
// set up. Logging is allowed.
-func (f *Flags) Post(cfg *config.Instance, _ platforms.Platform) {
+func (f *Flags) Post(cfg *config.Instance, pl platforms.Platform) {
switch {
+ case *f.ShowPairingCode:
+ handleShowPairingCode(cfg, pl)
+ os.Exit(0)
+ case *f.ListClients:
+ handleListClients(cfg, pl)
+ os.Exit(0)
+ case isFlagPassed("revoke-client"):
+ handleRevokeClient(cfg, pl, *f.RevokeClient)
+ os.Exit(0)
case isFlagPassed("write"):
if *f.Write == "" {
_, _ = fmt.Fprint(os.Stderr, "Error: write flag requires a value\n")
diff --git a/pkg/cli/clients.go b/pkg/cli/clients.go
new file mode 100644
index 000000000..09251e97b
--- /dev/null
+++ b/pkg/cli/clients.go
@@ -0,0 +1,213 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package cli
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/api"
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/config"
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb"
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms"
+ "github.com/rs/zerolog/log"
+)
+
+func generateQRCode(data string) {
+ // Simple ASCII QR code placeholder - in a real implementation, you'd use a QR code library
+ if _, err := fmt.Print("\n=== PAIRING QR CODE ===\n"); err != nil {
+ log.Error().Err(err).Msg("failed to print header")
+ }
+ if _, err := fmt.Print("Scan this with your mobile app:\n\n"); err != nil {
+ log.Error().Err(err).Msg("failed to print instruction")
+ }
+
+ // For now, just display the data - a real implementation would generate ASCII QR code
+ if _, err := fmt.Printf("Data: %s\n\n", data); err != nil {
+ log.Error().Err(err).Msg("failed to print data")
+ }
+
+ // You could use a library like github.com/skip2/go-qrcode for actual QR generation
+ note := "Note: QR code display not yet implemented - use manual pairing with the data above\n"
+ if _, err := fmt.Print(note); err != nil {
+ log.Error().Err(err).Msg("failed to print note")
+ }
+ if _, err := fmt.Print("======================\n\n"); err != nil {
+ log.Error().Err(err).Msg("failed to print footer")
+ }
+}
+
+func handleShowPairingCode(cfg *config.Instance, pl platforms.Platform) {
+ // Open user database to check pairing sessions
+ userDB, err := userdb.OpenUserDB(context.Background(), pl)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to open user database")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ os.Exit(1)
+ }
+ defer func() { _ = userDB.Close() }()
+
+ // Create HTTP client to call our own pairing API
+ client := &http.Client{Timeout: 10 * time.Second}
+
+ // Call pairing initiate endpoint
+ apiURL := fmt.Sprintf("http://localhost:%d/api/pair/initiate", cfg.APIPort())
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, apiURL, strings.NewReader("{}"))
+ if err != nil {
+ log.Error().Err(err).Msg("failed to create pairing request")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ return
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := client.Do(req)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to initiate pairing")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error initiating pairing: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ var pairingResp api.PairingInitiateResponse
+ if err := json.NewDecoder(resp.Body).Decode(&pairingResp); err != nil {
+ log.Error().Err(err).Msg("failed to decode pairing response")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error decoding response: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ return
+ }
+
+ // Generate QR code data
+ serverAddr := fmt.Sprintf("localhost:%d", cfg.APIPort())
+ qrData := api.QRCodeData{
+ Address: serverAddr,
+ Token: pairingResp.PairingToken,
+ }
+
+ jsonData, _ := json.Marshal(qrData)
+ generateQRCode(string(jsonData))
+
+ if _, err := fmt.Printf("Pairing token: %s\n", pairingResp.PairingToken); err != nil {
+ log.Error().Err(err).Msg("failed to print pairing token")
+ }
+ if _, err := fmt.Printf("Expires in: %d seconds\n", pairingResp.ExpiresIn); err != nil {
+ log.Error().Err(err).Msg("failed to print expiration time")
+ }
+ if _, err := fmt.Print("\nWaiting for client to pair... (Ctrl+C to cancel)\n"); err != nil {
+ log.Error().Err(err).Msg("failed to print waiting message")
+ }
+
+ // Wait for user to cancel
+ select {}
+}
+
+func handleListClients(_ *config.Instance, pl platforms.Platform) {
+ // Open user database to list clients
+ userDB, err := userdb.OpenUserDB(context.Background(), pl)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to open user database")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ os.Exit(1)
+ }
+ defer func() { _ = userDB.Close() }()
+
+ clients, err := userDB.GetAllClients()
+ if err != nil {
+ log.Error().Err(err).Msg("failed to get clients")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error getting clients: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ return
+ }
+
+ if len(clients) == 0 {
+ if _, err := fmt.Println("No paired clients found."); err != nil {
+ log.Error().Err(err).Msg("failed to print message")
+ }
+ return
+ }
+
+ if _, err := fmt.Print("Paired clients:\n\n"); err != nil {
+ log.Error().Err(err).Msg("failed to print header")
+ }
+ if _, err := fmt.Printf("%-36s %-20s %-10s %s\n", "Client ID", "Name", "Sequence", "Last Seen"); err != nil {
+ log.Error().Err(err).Msg("failed to print column headers")
+ }
+ if _, err := fmt.Printf("%s\n", strings.Repeat("-", 80)); err != nil {
+ log.Error().Err(err).Msg("failed to print separator")
+ }
+
+ for i := range clients {
+ client := &clients[i]
+ if _, err := fmt.Printf("%-36s %-20s %-10d %s\n",
+ client.ClientID,
+ client.ClientName,
+ client.CurrentSeq,
+ client.LastSeen.Format("2006-01-02 15:04:05"),
+ ); err != nil {
+ log.Error().Err(err).Msg("failed to print client info")
+ }
+ }
+}
+
+func handleRevokeClient(_ *config.Instance, pl platforms.Platform, clientID string) {
+ if clientID == "" {
+ if _, err := fmt.Fprint(os.Stderr, "Error: client ID is required\n"); err != nil {
+ log.Error().Err(err).Msg("failed to write error message")
+ }
+ os.Exit(1)
+ }
+
+ // Open user database to revoke client
+ userDB, err := userdb.OpenUserDB(context.Background(), pl)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to open user database")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ os.Exit(1)
+ }
+ defer func() { _ = userDB.Close() }()
+
+ err = userDB.DeleteClient(clientID)
+ if err != nil {
+ log.Error().Err(err).Msg("failed to delete client")
+ if _, writeErr := fmt.Fprintf(os.Stderr, "Error deleting client: %v\n", err); writeErr != nil {
+ log.Error().Err(writeErr).Msg("failed to write error message")
+ }
+ return
+ }
+
+ if _, err := fmt.Printf("Client %s has been revoked successfully.\n", clientID); err != nil {
+ log.Error().Err(err).Msg("failed to print success message")
+ }
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 8bd6ab319..b1baca1f2 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -359,6 +359,12 @@ func (c *Instance) AllowedOrigins() []string {
return c.vals.Service.AllowedOrigins
}
+func (c *Instance) DeviceID() string {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return c.vals.Service.DeviceID
+}
+
func (c *Instance) IsExecuteAllowed(s string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
diff --git a/pkg/database/database.go b/pkg/database/database.go
index ec7961a9c..73f9a4a83 100644
--- a/pkg/database/database.go
+++ b/pkg/database/database.go
@@ -69,6 +69,18 @@ type System struct {
DBID int64
}
+type Client struct {
+ CreatedAt time.Time `json:"createdAt"`
+ LastSeen time.Time `json:"lastSeen"`
+ ClientID string `json:"clientId"`
+ ClientName string `json:"clientName"`
+ AuthTokenHash string `json:"-"`
+ SharedSecret []byte `json:"-"`
+ SeqWindow []byte `json:"-"`
+ NonceCache []string `json:"-"`
+ CurrentSeq uint64 `json:"currentSeq"`
+}
+
type MediaTitle struct {
Slug string
Name string
@@ -154,6 +166,12 @@ type UserDBI interface {
GetZapLinkHost(host string) (bool, bool, error)
UpdateZapLinkCache(url string, zapscript string) error
GetZapLinkCache(url string) (string, error)
+ CreateClient(clientName, authToken string, sharedSecret []byte) (*Client, error)
+ GetClientByAuthToken(authToken string) (*Client, error)
+ GetClientByID(clientID string) (*Client, error)
+ UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error
+ GetAllClients() ([]Client, error)
+ DeleteClient(clientID string) error
}
type MediaDBI interface {
diff --git a/pkg/database/userdb/clients.go b/pkg/database/userdb/clients.go
new file mode 100644
index 000000000..a8c7ad50d
--- /dev/null
+++ b/pkg/database/userdb/clients.go
@@ -0,0 +1,280 @@
+// Zaparoo Core
+// Copyright (c) 2025 The Zaparoo Project Contributors.
+// SPDX-License-Identifier: GPL-3.0-or-later
+//
+// This file is part of Zaparoo Core.
+//
+// Zaparoo Core is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// Zaparoo Core is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with Zaparoo Core. If not, see .
+
+package userdb
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/ZaparooProject/zaparoo-core/v2/pkg/database"
+ "github.com/google/uuid"
+)
+
+func (db *UserDB) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) {
+ if db.sql == nil {
+ return nil, ErrNullSQL
+ }
+
+ clientID := uuid.New().String()
+ authTokenHash := hashAuthToken(authToken)
+ now := time.Now().Unix()
+
+ client := &database.Client{
+ ClientID: clientID,
+ ClientName: clientName,
+ AuthTokenHash: authTokenHash,
+ SharedSecret: sharedSecret,
+ CurrentSeq: 0,
+ SeqWindow: make([]byte, 8), // 64-bit window
+ NonceCache: make([]string, 0),
+ CreatedAt: time.Unix(now, 0),
+ LastSeen: time.Unix(now, 0),
+ }
+
+ nonceCacheJSON, err := json.Marshal(client.NonceCache)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal nonce cache: %w", err)
+ }
+
+ query := `
+ INSERT INTO clients (client_id, client_name, auth_token_hash, shared_secret,
+ current_seq, seq_window, nonce_cache, created_at, last_seen)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+ `
+
+ _, err = db.sql.ExecContext(context.Background(), query,
+ client.ClientID,
+ client.ClientName,
+ client.AuthTokenHash,
+ client.SharedSecret,
+ client.CurrentSeq,
+ client.SeqWindow,
+ string(nonceCacheJSON),
+ now,
+ now,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client: %w", err)
+ }
+
+ return client, nil
+}
+
+func (db *UserDB) GetClientByAuthToken(authToken string) (*database.Client, error) {
+ if db.sql == nil {
+ return nil, ErrNullSQL
+ }
+
+ authTokenHash := hashAuthToken(authToken)
+
+ query := `
+ SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq,
+ seq_window, nonce_cache, created_at, last_seen
+ FROM clients
+ WHERE auth_token_hash = ?
+ `
+
+ var client database.Client
+ var nonceCacheJSON string
+ var createdAt, lastSeen int64
+
+ err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan(
+ &client.ClientID,
+ &client.ClientName,
+ &client.AuthTokenHash,
+ &client.SharedSecret,
+ &client.CurrentSeq,
+ &client.SeqWindow,
+ &nonceCacheJSON,
+ &createdAt,
+ &lastSeen,
+ )
+ if err != nil {
+ // Use a generic error to avoid leaking information about whether the token exists
+ return nil, errors.New("invalid credentials")
+ }
+
+ client.CreatedAt = time.Unix(createdAt, 0)
+ client.LastSeen = time.Unix(lastSeen, 0)
+
+ if unmarshalErr := json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache); unmarshalErr != nil {
+ client.NonceCache = make([]string, 0) // Fallback to empty cache
+ }
+
+ return &client, nil
+}
+
+func (db *UserDB) GetClientByID(clientID string) (*database.Client, error) {
+ if db.sql == nil {
+ return nil, ErrNullSQL
+ }
+
+ query := `
+ SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq,
+ seq_window, nonce_cache, created_at, last_seen
+ FROM clients
+ WHERE client_id = ?
+ `
+
+ var client database.Client
+ var nonceCacheJSON string
+ var createdAt, lastSeen int64
+
+ err := db.sql.QueryRowContext(context.Background(), query, clientID).Scan(
+ &client.ClientID,
+ &client.ClientName,
+ &client.AuthTokenHash,
+ &client.SharedSecret,
+ &client.CurrentSeq,
+ &client.SeqWindow,
+ &nonceCacheJSON,
+ &createdAt,
+ &lastSeen,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("client not found: %w", err)
+ }
+
+ client.CreatedAt = time.Unix(createdAt, 0)
+ client.LastSeen = time.Unix(lastSeen, 0)
+
+ err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache)
+ if err != nil {
+ client.NonceCache = make([]string, 0) // Fallback to empty cache
+ }
+
+ return &client, nil
+}
+
+func (db *UserDB) UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error {
+ if db.sql == nil {
+ return ErrNullSQL
+ }
+
+ nonceCacheJSON, err := json.Marshal(nonceCache)
+ if err != nil {
+ return fmt.Errorf("failed to marshal nonce cache: %w", err)
+ }
+
+ query := `
+ UPDATE clients
+ SET current_seq = ?, seq_window = ?, nonce_cache = ?, last_seen = ?
+ WHERE client_id = ?
+ `
+
+ _, err = db.sql.ExecContext(
+ context.Background(), query, newSeq, seqWindow, string(nonceCacheJSON), time.Now().Unix(), clientID,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to update client sequence: %w", err)
+ }
+
+ return nil
+}
+
+func (db *UserDB) GetAllClients() ([]database.Client, error) {
+ if db.sql == nil {
+ return nil, ErrNullSQL
+ }
+
+ query := `
+ SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq,
+ seq_window, nonce_cache, created_at, last_seen
+ FROM clients
+ ORDER BY last_seen DESC
+ `
+
+ rows, err := db.sql.QueryContext(context.Background(), query)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query clients: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ clients := make([]database.Client, 0)
+ for rows.Next() {
+ var client database.Client
+ var nonceCacheJSON string
+ var createdAt, lastSeen int64
+
+ scanErr := rows.Scan(
+ &client.ClientID,
+ &client.ClientName,
+ &client.AuthTokenHash,
+ &client.SharedSecret,
+ &client.CurrentSeq,
+ &client.SeqWindow,
+ &nonceCacheJSON,
+ &createdAt,
+ &lastSeen,
+ )
+ if scanErr != nil {
+ return nil, fmt.Errorf("failed to scan client row: %w", scanErr)
+ }
+
+ client.CreatedAt = time.Unix(createdAt, 0)
+ client.LastSeen = time.Unix(lastSeen, 0)
+
+ err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache)
+ if err != nil {
+ client.NonceCache = make([]string, 0) // Fallback to empty cache
+ }
+
+ clients = append(clients, client)
+ }
+
+ if err = rows.Err(); err != nil {
+ return nil, fmt.Errorf("error reading client rows: %w", err)
+ }
+
+ return clients, nil
+}
+
+func (db *UserDB) DeleteClient(clientID string) error {
+ if db.sql == nil {
+ return ErrNullSQL
+ }
+
+ query := `DELETE FROM clients WHERE client_id = ?`
+ result, err := db.sql.ExecContext(context.Background(), query, clientID)
+ if err != nil {
+ return fmt.Errorf("failed to delete client: %w", err)
+ }
+
+ rowsAffected, err := result.RowsAffected()
+ if err != nil {
+ return fmt.Errorf("failed to get rows affected: %w", err)
+ }
+
+ if rowsAffected == 0 {
+ return errors.New("client not found")
+ }
+
+ return nil
+}
+
+func hashAuthToken(authToken string) string {
+ hash := sha256.Sum256([]byte(authToken))
+ return hex.EncodeToString(hash[:])
+}
diff --git a/pkg/database/userdb/migrations/20250912231204_clients_auth.sql b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql
new file mode 100644
index 000000000..b3380c6d7
--- /dev/null
+++ b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql
@@ -0,0 +1,22 @@
+-- +goose Up
+-- +goose StatementBegin
+CREATE TABLE clients (
+ client_id TEXT PRIMARY KEY,
+ client_name TEXT NOT NULL,
+ auth_token_hash TEXT NOT NULL UNIQUE,
+ shared_secret BLOB NOT NULL,
+ current_seq INTEGER DEFAULT 0,
+ seq_window BLOB,
+ nonce_cache TEXT DEFAULT '[]',
+ created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
+ last_seen INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
+);
+
+CREATE INDEX idx_clients_auth_token ON clients(auth_token_hash);
+CREATE INDEX idx_clients_last_seen ON clients(last_seen);
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+DROP TABLE clients;
+-- +goose StatementEnd
\ No newline at end of file
diff --git a/pkg/testing/helpers/db_mocks.go b/pkg/testing/helpers/db_mocks.go
index b95b4e178..2607e8d64 100644
--- a/pkg/testing/helpers/db_mocks.go
+++ b/pkg/testing/helpers/db_mocks.go
@@ -57,6 +57,11 @@ import (
"github.com/stretchr/testify/mock"
)
+// Sentinel errors for mock functions
+var (
+ ErrMockNotConfigured = errors.New("mock not configured for this call")
+)
+
// MockUserDBI is a mock implementation of the UserDBI interface using testify/mock
type MockUserDBI struct {
mock.Mock
@@ -239,6 +244,81 @@ func (m *MockUserDBI) GetZapLinkCache(url string) (string, error) {
return args.String(0), args.Error(1)
}
+// Client authentication methods
+func (m *MockUserDBI) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) {
+ args := m.Called(clientName, authToken, sharedSecret)
+ if client, ok := args.Get(0).(*database.Client); ok {
+ if err := args.Error(1); err != nil {
+ return client, fmt.Errorf("mock UserDBI create client failed: %w", err)
+ }
+ return client, nil
+ }
+ if err := args.Error(1); err != nil {
+ return nil, fmt.Errorf("mock UserDBI create client failed: %w", err)
+ }
+ return nil, ErrMockNotConfigured
+}
+
+func (m *MockUserDBI) GetClientByAuthToken(authToken string) (*database.Client, error) {
+ args := m.Called(authToken)
+ if client, ok := args.Get(0).(*database.Client); ok {
+ if err := args.Error(1); err != nil {
+ return client, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err)
+ }
+ return client, nil
+ }
+ if err := args.Error(1); err != nil {
+ return nil, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err)
+ }
+ return nil, ErrMockNotConfigured
+}
+
+func (m *MockUserDBI) GetClientByID(clientID string) (*database.Client, error) {
+ args := m.Called(clientID)
+ if client, ok := args.Get(0).(*database.Client); ok {
+ if err := args.Error(1); err != nil {
+ return client, fmt.Errorf("mock UserDBI get client by ID failed: %w", err)
+ }
+ return client, nil
+ }
+ if err := args.Error(1); err != nil {
+ return nil, fmt.Errorf("mock UserDBI get client by ID failed: %w", err)
+ }
+ return nil, ErrMockNotConfigured
+}
+
+func (m *MockUserDBI) UpdateClientSequence(
+ clientID string, newSeq uint64, seqWindow []byte, nonceCache []string,
+) error {
+ args := m.Called(clientID, newSeq, seqWindow, nonceCache)
+ if err := args.Error(0); err != nil {
+ return fmt.Errorf("mock UserDBI update client sequence failed: %w", err)
+ }
+ return nil
+}
+
+func (m *MockUserDBI) GetAllClients() ([]database.Client, error) {
+ args := m.Called()
+ if clients, ok := args.Get(0).([]database.Client); ok {
+ if err := args.Error(1); err != nil {
+ return clients, fmt.Errorf("mock UserDBI get all clients failed: %w", err)
+ }
+ return clients, nil
+ }
+ if err := args.Error(1); err != nil {
+ return nil, fmt.Errorf("mock UserDBI get all clients failed: %w", err)
+ }
+ return nil, nil
+}
+
+func (m *MockUserDBI) DeleteClient(clientID string) error {
+ args := m.Called(clientID)
+ if err := args.Error(0); err != nil {
+ return fmt.Errorf("mock UserDBI delete client failed: %w", err)
+ }
+ return nil
+}
+
// MockMediaDBI is a mock implementation of the MediaDBI interface using testify/mock
type MockMediaDBI struct {
mock.Mock