From 83210eeed9703eaae2b213d14aac912585c32a21 Mon Sep 17 00:00:00 2001 From: Callan Barrett Date: Sat, 13 Sep 2025 08:00:53 +0800 Subject: [PATCH 1/5] initial encryption work --- pkg/api/middleware/auth.go | 489 ++++++++++++++++ pkg/api/middleware/auth_test.go | 523 ++++++++++++++++++ pkg/api/middleware/ratelimit_test.go | 206 +++++++ pkg/api/pairing.go | 271 +++++++++ pkg/api/server.go | 307 +++++++++- pkg/cli/cli.go | 49 +- pkg/cli/devices.go | 213 +++++++ pkg/database/database.go | 18 + pkg/database/userdb/devices.go | 280 ++++++++++ .../20250912231204_devices_auth.sql | 22 + pkg/testing/helpers/db_mocks.go | 80 +++ 11 files changed, 2422 insertions(+), 36 deletions(-) create mode 100644 pkg/api/middleware/auth.go create mode 100644 pkg/api/middleware/auth_test.go create mode 100644 pkg/api/middleware/ratelimit_test.go create mode 100644 pkg/api/pairing.go create mode 100644 pkg/cli/devices.go create mode 100644 pkg/database/userdb/devices.go create mode 100644 pkg/database/userdb/migrations/20250912231204_devices_auth.sql diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go new file mode 100644 index 000000000..50150c858 --- /dev/null +++ b/pkg/api/middleware/auth.go @@ -0,0 +1,489 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "slices" + "sync" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/rs/zerolog/log" +) + +const ( + SequenceWindow = 64 // Size of sliding window for sequence numbers + NonceCacheSize = 100 // Maximum number of cached nonces + MutexCleanupInterval = 10 * time.Minute // Cleanup unused mutexes every 10 minutes + MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes +) + +type deviceKey string + +// DeviceMutexManager handles per-device locking to prevent race conditions +// in authentication state updates +type DeviceMutexManager struct { + mutexes sync.Map // map[string]*deviceMutex +} + +type deviceMutex struct { + mu sync.Mutex + lastUsed time.Time + deviceID string +} + +var globalDeviceMutexManager = &DeviceMutexManager{} + +type EncryptedRequest struct { + Encrypted string `json:"encrypted"` + IV string `json:"iv"` + AuthToken string `json:"authToken"` +} + +type DecryptedPayload struct { + ID any `json:"id,omitempty"` + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Nonce string `json:"nonce"` + Params json.RawMessage `json:"params,omitempty"` + Seq uint64 `json:"seq"` +} + +func isLocalhost(remoteAddr string) bool { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + host = remoteAddr + } + + ip := net.ParseIP(host) + if ip == nil { + return host == "localhost" + } + + return ip.IsLoopback() +} + +func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip authentication for localhost connections + if isLocalhost(r.RemoteAddr) { + log.Debug().Str("remote_addr", r.RemoteAddr).Msg("localhost connection - skipping auth") + next.ServeHTTP(w, r) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + log.Error().Err(err).Msg("failed to read request body") + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + // Try to parse as encrypted request + var encReq EncryptedRequest + if parseErr := json.Unmarshal(body, &encReq); parseErr != nil { + log.Error().Err(parseErr).Msg("invalid encrypted request format") + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + + // Validate auth token and get device + device, err := db.UserDB.GetDeviceByAuthToken(encReq.AuthToken) + if err != nil { + tokenStr := "empty" + if len(encReq.AuthToken) >= 8 { + tokenStr = encReq.AuthToken[:8] + "..." + } else if encReq.AuthToken != "" { + tokenStr = encReq.AuthToken + } + log.Error().Err(err).Str("token", tokenStr).Msg("invalid auth token") + http.Error(w, "Invalid auth token", http.StatusUnauthorized) + return + } + + // Decrypt payload + decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, device.SharedSecret) + if err != nil { + log.Error().Err(err).Msg("failed to decrypt payload") + http.Error(w, "Decryption failed", http.StatusBadRequest) + return + } + + // Parse decrypted payload + var payload DecryptedPayload + if parseErr := json.Unmarshal(decryptedPayload, &payload); parseErr != nil { + log.Error().Err(parseErr).Msg("invalid decrypted payload format") + http.Error(w, "Invalid payload format", http.StatusBadRequest) + return + } + + // CRITICAL SECTION: Acquire device lock to prevent race conditions + // between validation and database update + unlockDevice := LockDevice(device.DeviceID) + defer unlockDevice() + + // Re-fetch device state under lock to get latest sequence/nonce state + freshDevice, err := db.UserDB.GetDeviceByAuthToken(encReq.AuthToken) + if err != nil { + log.Error().Err(err).Msg("failed to re-fetch device under lock") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Validate sequence number and nonce with fresh device state + if !ValidateSequenceAndNonce(freshDevice, payload.Seq, payload.Nonce) { + log.Warn(). + Str("device_id", freshDevice.DeviceID). + Uint64("seq", payload.Seq). + Str("nonce", payload.Nonce). + Msg("invalid sequence or replay attack detected") + http.Error(w, "Invalid sequence or replay detected", http.StatusBadRequest) + return + } + + // Update device state (sequence and nonce cache) + updatedDevice := *freshDevice + updateDeviceSequenceAndNonce(&updatedDevice, payload.Seq, payload.Nonce) + + // Save to database (still under lock) + if updateErr := db.UserDB.UpdateDeviceSequence( + updatedDevice.DeviceID, updatedDevice.CurrentSeq, updatedDevice.SeqWindow, updatedDevice.NonceCache, + ); updateErr != nil { + log.Error().Err(updateErr).Msg("failed to update device state") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Update the device pointer for context (use fresh device with updates) + device = &updatedDevice + + // Replace request body with decrypted JSON-RPC payload + originalPayload := map[string]any{ + "jsonrpc": payload.JSONRPC, + "method": payload.Method, + "id": payload.ID, + } + if payload.Params != nil { + originalPayload["params"] = payload.Params + } + + newBody, err := json.Marshal(originalPayload) + if err != nil { + log.Error().Err(err).Msg("failed to marshal decrypted payload") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Create new request with decrypted body + r.Body = io.NopCloser(bytes.NewReader(newBody)) + r.ContentLength = int64(len(newBody)) + + // Store device in context for potential use by handlers + ctx := context.WithValue(r.Context(), deviceKey("device"), device) + r = r.WithContext(ctx) + + log.Debug(). + Str("device_id", device.DeviceID). + Str("method", payload.Method). + Uint64("seq", payload.Seq). + Msg("authenticated request processed") + + next.ServeHTTP(w, r) + }) + } +} + +func DecryptPayload(encryptedB64, ivB64 string, key []byte) ([]byte, error) { + // Decode base64 + encrypted, err := base64.StdEncoding.DecodeString(encryptedB64) + if err != nil { + return nil, fmt.Errorf("invalid encrypted data: %w", err) + } + + iv, err := base64.StdEncoding.DecodeString(ivB64) + if err != nil { + return nil, fmt.Errorf("invalid IV: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Decrypt + plaintext, err := gcm.Open(nil, iv, encrypted, nil) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + return plaintext, nil +} + +func EncryptPayload(data, key []byte) (encrypted, iv string, err error) { + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return "", "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random IV + ivBytes := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, ivBytes); err != nil { + return "", "", fmt.Errorf("failed to generate IV: %w", err) + } + + // Encrypt + ciphertext := gcm.Seal(nil, ivBytes, data, nil) + + // Return base64 encoded values + return base64.StdEncoding.EncodeToString(ciphertext), + base64.StdEncoding.EncodeToString(ivBytes), nil +} + +func ValidateSequenceAndNonce(device *database.Device, seq uint64, nonce string) bool { + // Check if nonce was recently used (replay protection) + if slices.Contains(device.NonceCache, nonce) { + return false + } + + // Validate sequence number with sliding window + if seq <= device.CurrentSeq { + // Check if sequence is within acceptable window + diff := device.CurrentSeq - seq + if diff >= SequenceWindow { + return false // Too old + } + + // Check if this sequence was already processed (using seq_window bitmap) + windowPos := diff % SequenceWindow + bytePos := windowPos / 8 + bitPos := windowPos % 8 + + if bytePos < uint64(len(device.SeqWindow)) { + if (device.SeqWindow[bytePos] & (1 << bitPos)) != 0 { + return false // Already processed + } + } + } + + return true +} + +func updateDeviceSequenceAndNonce(device *database.Device, seq uint64, nonce string) { + // Update nonce cache (keep last NonceCacheSize nonces) + device.NonceCache = append(device.NonceCache, nonce) + if len(device.NonceCache) > NonceCacheSize { + device.NonceCache = device.NonceCache[1:] // Remove oldest + } + + // Update sequence window + if seq > device.CurrentSeq { + // New highest sequence - shift window + shift := seq - device.CurrentSeq + if shift >= SequenceWindow { + // Clear entire window + device.SeqWindow = make([]byte, 8) + } else { + // Shift window right + for range shift { + shiftWindowRight(device.SeqWindow) + } + } + device.CurrentSeq = seq + + // Mark current sequence as processed (position 0 in window) + device.SeqWindow[0] |= 1 + } else { + // Mark this sequence as processed in the window + diff := device.CurrentSeq - seq + windowPos := diff % SequenceWindow + bytePos := windowPos / 8 + bitPos := windowPos % 8 + + if bytePos < uint64(len(device.SeqWindow)) { + device.SeqWindow[bytePos] |= (1 << bitPos) + } + } +} + +func UpdateDeviceState(userDB *userdb.UserDB, device *database.Device, seq uint64, nonce string) error { + // Update nonce cache (keep last NonceCacheSize nonces) + device.NonceCache = append(device.NonceCache, nonce) + if len(device.NonceCache) > NonceCacheSize { + device.NonceCache = device.NonceCache[1:] // Remove oldest + } + + // Update sequence window + if seq > device.CurrentSeq { + // New highest sequence - shift window + shift := seq - device.CurrentSeq + if shift >= SequenceWindow { + // Clear entire window + device.SeqWindow = make([]byte, 8) + } else { + // Shift window right + for range shift { + shiftWindowRight(device.SeqWindow) + } + } + device.CurrentSeq = seq + + // Mark current sequence as processed (position 0 in window) + device.SeqWindow[0] |= 1 + } else { + // Mark this sequence as processed in the window + diff := device.CurrentSeq - seq + windowPos := diff % SequenceWindow + bytePos := windowPos / 8 + bitPos := windowPos % 8 + + if bytePos < uint64(len(device.SeqWindow)) { + device.SeqWindow[bytePos] |= (1 << bitPos) + } + } + + // Update database + if err := userDB.UpdateDeviceSequence( + device.DeviceID, device.CurrentSeq, device.SeqWindow, device.NonceCache, + ); err != nil { + return fmt.Errorf("failed to update device sequence: %w", err) + } + return nil +} + +// getDeviceMutex retrieves or creates a mutex for the given device ID +func (dm *DeviceMutexManager) getDeviceMutex(deviceID string) *deviceMutex { + // Try to load existing mutex + if value, exists := dm.mutexes.Load(deviceID); exists { + mutex := value.(*deviceMutex) + mutex.lastUsed = time.Now() + return mutex + } + + // Create new mutex + newMutex := &deviceMutex{ + lastUsed: time.Now(), + deviceID: deviceID, + } + + // Store and return the mutex (LoadOrStore handles race conditions) + actual, _ := dm.mutexes.LoadOrStore(deviceID, newMutex) + actualMutex := actual.(*deviceMutex) + actualMutex.lastUsed = time.Now() + return actualMutex +} + +// lockDevice acquires a lock for the specified device, preventing race conditions +// in authentication state updates. The returned function must be called to release the lock. +func (dm *DeviceMutexManager) lockDevice(deviceID string) func() { + mutex := dm.getDeviceMutex(deviceID) + mutex.mu.Lock() + + return func() { + mutex.mu.Unlock() + } +} + +// cleanup removes unused mutexes to prevent memory leaks +func (dm *DeviceMutexManager) cleanup() { + now := time.Now() + dm.mutexes.Range(func(key, value interface{}) bool { + mutex := value.(*deviceMutex) + if now.Sub(mutex.lastUsed) > MutexMaxIdle { + dm.mutexes.Delete(key) + log.Debug().Str("device_id", mutex.deviceID).Msg("cleaned up unused device mutex") + } + return true + }) +} + +// startCleanupRoutine starts a background goroutine to periodically clean up unused mutexes +func (dm *DeviceMutexManager) startCleanupRoutine() { + go func() { + ticker := time.NewTicker(MutexCleanupInterval) + defer ticker.Stop() + + for range ticker.C { + dm.cleanup() + } + }() +} + +// GetDeviceMutex is a convenience function to get a mutex for a device +func GetDeviceMutex(deviceID string) *deviceMutex { + return globalDeviceMutexManager.getDeviceMutex(deviceID) +} + +// LockDevice is a convenience function to lock a device +func LockDevice(deviceID string) func() { + return globalDeviceMutexManager.lockDevice(deviceID) +} + +func init() { + // Start cleanup routine for device mutexes + globalDeviceMutexManager.startCleanupRoutine() +} + +func shiftWindowRight(window []byte) { + carry := byte(0) + for i := len(window) - 1; i >= 0; i-- { + newCarry := (window[i] & 1) << 7 + window[i] = (window[i] >> 1) | carry + carry = newCarry + } +} + +func IsAuthenticatedConnection(r *http.Request) bool { + return !isLocalhost(r.RemoteAddr) +} + +func GetDeviceFromContext(ctx context.Context) *database.Device { + if device, ok := ctx.Value(deviceKey("device")).(*database.Device); ok { + return device + } + return nil +} diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go new file mode 100644 index 000000000..9aeef8ca2 --- /dev/null +++ b/pkg/api/middleware/auth_test.go @@ -0,0 +1,523 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/testing/helpers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestAuthMiddleware_LocalhostBypass(t *testing.T) { + t.Parallel() + // Setup + userDB := helpers.NewMockUserDBI() + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + expectPass bool + }{ + {"localhost with port", "127.0.0.1:12345", true}, + {"localhost without port", "127.0.0.1", true}, + {"localhost name with port", "localhost:8080", true}, + {"localhost name without port", "localhost", true}, + {"IPv6 loopback", "[::1]:8080", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "localhost should bypass auth") + assert.Equal(t, "success", w.Body.String()) + }) + } +} + +func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { + t.Parallel() + // Setup + userDB := helpers.NewMockUserDBI() + // Mock the GetDeviceByAuthToken call to return an error for empty token + // The middleware calls this once per test case (we have 2 test cases) + userDB.On("GetDeviceByAuthToken", "").Return((*database.Device)(nil), assert.AnError).Times(2) + + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called for unauthenticated remote requests") + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + }{ + {"private network IP", "192.168.1.100:5000"}, + {"public IP", "8.8.8.8:80"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Send regular JSON without proper auth fields - should fail at auth token lookup + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Should fail because auth token is empty/invalid + assert.Equal(t, http.StatusUnauthorized, w.Code, "remote should fail auth") + assert.Contains(t, w.Body.String(), "Invalid auth token") + }) + } + + userDB.AssertExpectations(t) +} + +func TestAuthMiddleware_EncryptedRequest(t *testing.T) { + t.Parallel() + // Create a mock device with known shared secret + testSecret := []byte("test-secret-key-32-bytes-long-ok") + testDevice := &database.Device{ + DeviceID: "test-device-id", + DeviceName: "Test Device", + AuthTokenHash: "test-token-hash", + SharedSecret: testSecret, + CurrentSeq: 0, + SeqWindow: make([]byte, 8), + NonceCache: []string{}, + CreatedAt: time.Now(), + LastSeen: time.Now(), + } + + userDB := helpers.NewMockUserDBI() + userDB.On("GetDeviceByAuthToken", "test-auth-token").Return(testDevice, nil) + userDB.On("UpdateDeviceSequence", "test-device-id", uint64(1), + mock.AnythingOfType("[]uint8"), mock.AnythingOfType("[]string")).Return(nil) + + db := &database.Database{UserDB: userDB} + + // Create encrypted payload + payload := DecryptedPayload{ + JSONRPC: "2.0", + Method: "test.method", + ID: 1, + Seq: 1, + Nonce: "test-nonce-123", + } + + payloadJSON, err := json.Marshal(payload) + require.NoError(t, err) + + encryptedData, iv, err := EncryptPayload(payloadJSON, testSecret) + require.NoError(t, err) + + encRequest := EncryptedRequest{ + Encrypted: encryptedData, + IV: iv, + AuthToken: "test-auth-token", + } + + encRequestJSON, err := json.Marshal(encRequest) + require.NoError(t, err) + + // Test handler that checks the decrypted request + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request body was decrypted properly + var jsonRPC map[string]any + err := json.NewDecoder(r.Body).Decode(&jsonRPC) + assert.NoError(t, err) + + assert.Equal(t, "2.0", jsonRPC["jsonrpc"]) + assert.Equal(t, "test.method", jsonRPC["method"]) + assert.InDelta(t, float64(1), jsonRPC["id"], 0.001) // JSON unmarshals numbers as float64 + + // Verify device is in context + device := GetDeviceFromContext(r.Context()) + assert.NotNil(t, device) + assert.Equal(t, "test-device-id", device.DeviceID) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("authenticated success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + // Test successful authentication + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(encRequestJSON)) + req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "authenticated success", w.Body.String()) + userDB.AssertExpectations(t) +} + +func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { + t.Parallel() + tests := []struct { + name string + newNonce string + description string + seqWindow []byte + nonceCache []string + currentSeq uint64 + newSeq uint64 + expectedResult bool + }{ + { + name: "first message", + currentSeq: 0, + seqWindow: make([]byte, 8), + nonceCache: []string{}, + newSeq: 1, + newNonce: "nonce1", + expectedResult: true, + description: "first message should always pass", + }, + { + name: "sequence increment", + currentSeq: 5, + seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + nonceCache: []string{"old-nonce"}, + newSeq: 6, + newNonce: "nonce6", + expectedResult: true, + description: "incrementing sequence should pass", + }, + { + name: "duplicate nonce", + currentSeq: 5, + seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + nonceCache: []string{"duplicate-nonce"}, + newSeq: 6, + newNonce: "duplicate-nonce", + expectedResult: false, + description: "duplicate nonce should be rejected", + }, + { + name: "old sequence out of window", + currentSeq: 100, + seqWindow: make([]byte, 8), + nonceCache: []string{}, + newSeq: 10, // More than 64 behind + newNonce: "nonce10", + expectedResult: false, + description: "sequence too far behind should be rejected", + }, + { + name: "sequence within window", + currentSeq: 10, + seqWindow: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + nonceCache: []string{}, + newSeq: 8, // 2 behind, within window + newNonce: "nonce8", + expectedResult: true, + description: "sequence within sliding window should pass", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + device := &database.Device{ + DeviceID: "test-device", + CurrentSeq: tt.currentSeq, + SeqWindow: tt.seqWindow, + NonceCache: tt.nonceCache, + } + + result := ValidateSequenceAndNonce(device, tt.newSeq, tt.newNonce) + assert.Equal(t, tt.expectedResult, result, tt.description) + }) + } +} + +func TestEncryptDecryptPayload(t *testing.T) { + t.Parallel() + testKey := []byte("test-encryption-key-32bytes-ok!!") + originalData := []byte(`{"jsonrpc":"2.0","method":"test","id":1}`) + + // Test encryption + encrypted, iv, err := EncryptPayload(originalData, testKey) + require.NoError(t, err) + assert.NotEmpty(t, encrypted) + assert.NotEmpty(t, iv) + + // Test decryption + decrypted, err := DecryptPayload(encrypted, iv, testKey) + require.NoError(t, err) + assert.Equal(t, originalData, decrypted) +} + +func TestEncryptDecryptPayload_WrongKey(t *testing.T) { + t.Parallel() + correctKey := []byte("correct-key-32-bytes-long-ok!!!!") + wrongKey := []byte("wrong-key-32-bytes-long-ok!!!!!!") + originalData := []byte(`{"test": "data"}`) + + // Encrypt with correct key + encrypted, iv, err := EncryptPayload(originalData, correctKey) + require.NoError(t, err) + + // Try to decrypt with wrong key - should fail + _, err = DecryptPayload(encrypted, iv, wrongKey) + require.Error(t, err, "decryption should fail with wrong key") + assert.Contains(t, err.Error(), "decryption failed") +} + +func TestIsLocalhost(t *testing.T) { + t.Parallel() + tests := []struct { + addr string + expected bool + }{ + {"127.0.0.1:8080", true}, + {"127.0.0.1", true}, + {"localhost:3000", true}, + {"localhost", true}, + {"[::1]:8080", true}, + {"::1", true}, + {"192.168.1.1:8080", false}, + {"8.8.8.8:53", false}, + {"example.com:80", false}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + t.Parallel() + result := isLocalhost(tt.addr) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAuthMiddleware_InvalidRequests(t *testing.T) { + t.Parallel() + userDB := helpers.NewMockUserDBI() + // Mock empty auth token lookup - expect it to be called once for the missing auth token test + userDB.On("GetDeviceByAuthToken", "").Return((*database.Device)(nil), assert.AnError) + + db := &database.Database{UserDB: userDB} + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called for invalid requests") + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + body string + description string + expectedCode int + }{ + { + name: "invalid json", + body: `{"invalid json`, + expectedCode: http.StatusBadRequest, + description: "malformed JSON should be rejected", + }, + { + name: "missing auth token", + body: `{"encrypted": "data", "iv": "iv"}`, + expectedCode: http.StatusUnauthorized, + description: "missing auth token should be rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(tt.body))) + req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedCode, w.Code, tt.description) + }) + } + + userDB.AssertExpectations(t) +} + +func TestGetDeviceFromContext(t *testing.T) { + t.Parallel() + // Test with device in context + device := &database.Device{DeviceID: "test-device"} + ctx := context.WithValue(context.Background(), deviceKey("device"), device) + + result := GetDeviceFromContext(ctx) + assert.Equal(t, device, result) + + // Test with no device in context + emptyCtx := context.Background() + result = GetDeviceFromContext(emptyCtx) + assert.Nil(t, result) + + // Test with wrong type in context + badCtx := context.WithValue(context.Background(), deviceKey("device"), "not-a-device") + result = GetDeviceFromContext(badCtx) + assert.Nil(t, result) +} + +// TestAuthMiddleware_ConcurrentRequests verifies that the race condition fix +// prevents concurrent requests from bypassing replay protection +func TestAuthMiddleware_ConcurrentRequests(t *testing.T) { + t.Parallel() + + // This test verifies the mutex locking works correctly + // We'll just test that concurrent access to the mutex manager is safe + + const numConcurrentRequests = 20 + const deviceID = "test-device-concurrent" + + done := make(chan struct{}, numConcurrentRequests) + var lockAcquired int32 + + for i := 0; i < numConcurrentRequests; i++ { + go func() { + defer func() { done <- struct{}{} }() + + // Acquire device lock - this should be thread-safe + unlockDevice := LockDevice(deviceID) + + // Critical section - only one goroutine should be here at a time + current := atomic.AddInt32(&lockAcquired, 1) + if current != 1 { + t.Errorf("Race condition detected: %d goroutines in critical section", current) + } + + // Simulate some work + time.Sleep(1 * time.Millisecond) + + atomic.AddInt32(&lockAcquired, -1) + unlockDevice() + }() + } + + // Wait for all requests to complete + for i := 0; i < numConcurrentRequests; i++ { + <-done + } + + // Verify no race conditions occurred + assert.Equal(t, int32(0), atomic.LoadInt32(&lockAcquired), "All locks should be released") +} + +// TestDeviceMutexManager_Cleanup verifies mutex cleanup works correctly +func TestDeviceMutexManager_Cleanup(t *testing.T) { + t.Parallel() + + dm := &DeviceMutexManager{} + + // Create some mutexes + mutex1 := dm.getDeviceMutex("device1") + mutex2 := dm.getDeviceMutex("device2") + mutex3 := dm.getDeviceMutex("device3") + + require.NotNil(t, mutex1) + require.NotNil(t, mutex2) + require.NotNil(t, mutex3) + + // Age some mutexes + mutex1.lastUsed = time.Now().Add(-31 * time.Minute) // Should be cleaned + mutex2.lastUsed = time.Now().Add(-20 * time.Minute) // Should remain + mutex3.lastUsed = time.Now() // Should remain + + // Run cleanup + dm.cleanup() + + // Check that only old mutex was removed + _, exists1 := dm.mutexes.Load("device1") + _, exists2 := dm.mutexes.Load("device2") + _, exists3 := dm.mutexes.Load("device3") + + assert.False(t, exists1, "Old mutex should be cleaned up") + assert.True(t, exists2, "Recent mutex should remain") + assert.True(t, exists3, "Current mutex should remain") +} + +// TestDeviceMutexManager_ConcurrentAccess verifies thread safety of mutex manager +func TestDeviceMutexManager_ConcurrentAccess(t *testing.T) { + t.Parallel() + + dm := &DeviceMutexManager{} + const numGoroutines = 50 + const deviceID = "concurrent-test-device" + + done := make(chan struct{}, numGoroutines) + + // Launch multiple goroutines that get/create mutex for same device + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { done <- struct{}{} }() + + mutex := dm.getDeviceMutex(deviceID) + require.NotNil(t, mutex) + assert.Equal(t, deviceID, mutex.deviceID) + }() + } + + // Wait for all goroutines to complete + for i := 0; i < numGoroutines; i++ { + <-done + } + + // Verify only one mutex was created for the device + value, exists := dm.mutexes.Load(deviceID) + assert.True(t, exists, "Mutex should exist") + + mutex := value.(*deviceMutex) + assert.Equal(t, deviceID, mutex.deviceID) +} diff --git a/pkg/api/middleware/ratelimit_test.go b/pkg/api/middleware/ratelimit_test.go new file mode 100644 index 000000000..891d7002b --- /dev/null +++ b/pkg/api/middleware/ratelimit_test.go @@ -0,0 +1,206 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestIPRateLimiter_BasicFunctionality(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiter for IP + rl := limiter.GetLimiter("192.168.1.100") + assert.NotNil(t, rl) + + // Should allow initial requests up to burst size + for i := range BurstSize { + allowed := rl.Allow() + assert.True(t, allowed, "should allow request %d within burst", i+1) + } + + // Should block additional requests beyond burst + blocked := rl.Allow() + assert.False(t, blocked, "should block request beyond burst size") +} + +func TestIPRateLimiter_DifferentIPs(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiters for different IPs + rl1 := limiter.GetLimiter("192.168.1.100") + rl2 := limiter.GetLimiter("192.168.1.101") + + // Should be different limiters + assert.NotSame(t, rl1, rl2) + + // Exhaust first limiter + for range BurstSize { + rl1.Allow() + } + + // First limiter should be blocked + assert.False(t, rl1.Allow()) + + // Second limiter should still allow requests + assert.True(t, rl2.Allow()) +} + +func TestIPRateLimiter_SameIPReuse(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Get limiter for same IP twice + rl1 := limiter.GetLimiter("192.168.1.100") + rl2 := limiter.GetLimiter("192.168.1.100") + + // Should be same limiter instance + assert.Same(t, rl1, rl2) +} + +func TestHTTPRateLimitMiddleware_Allow(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + // Should allow initial requests + for i := range BurstSize { + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = "192.168.1.100:12345" + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "should allow request %d", i+1) + assert.Equal(t, "success", w.Body.String()) + } +} + +func TestHTTPRateLimitMiddleware_Block(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Error("handler should not be called when rate limited") + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + // Exhaust rate limit + ip := "192.168.1.100:12345" + ipLimiter := limiter.GetLimiter("192.168.1.100") + for range BurstSize { + ipLimiter.Allow() + } + + // Next request should be blocked + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = ip + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTooManyRequests, w.Code) + assert.Contains(t, w.Body.String(), "Too Many Requests") +} + +func TestIPRateLimiter_Cleanup(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + // Add a limiter with old timestamp manually + limiter.limiters["old.ip"] = &rateLimiterEntry{ + limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize), + lastSeen: time.Now().Add(-15 * time.Minute), // Old timestamp + } + + // Add a limiter with recent timestamp manually + limiter.limiters["new.ip"] = &rateLimiterEntry{ + limiter: rate.NewLimiter(rate.Limit(float64(RequestsPerMinute)/60.0), BurstSize), + lastSeen: time.Now(), // Recent timestamp + } + + // Verify both exist + assert.Len(t, limiter.limiters, 2) + + // Run cleanup + limiter.Cleanup() + + // Old entry should be removed, new one should remain + assert.Len(t, limiter.limiters, 1) + assert.Contains(t, limiter.limiters, "new.ip") + assert.NotContains(t, limiter.limiters, "old.ip") +} + +func TestHTTPRateLimitMiddleware_IPExtraction(t *testing.T) { + t.Parallel() + limiter := NewIPRateLimiter() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := HTTPRateLimitMiddleware(limiter) + wrappedHandler := middleware(handler) + + tests := []struct { + name string + remoteAddr string + expectedIP string + }{ + {"with port", "192.168.1.100:12345", "192.168.1.100"}, + {"without port", "192.168.1.100", "192.168.1.100"}, + {"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/api/test", http.NoBody) + req.RemoteAddr = tt.remoteAddr + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + // Should succeed (IP extraction worked) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify the correct IP was used for rate limiting + ipLimiter := limiter.GetLimiter(tt.expectedIP) + assert.NotNil(t, ipLimiter) + }) + } +} diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go new file mode 100644 index 000000000..683ef5b16 --- /dev/null +++ b/pkg/api/pairing.go @@ -0,0 +1,271 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package api + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/hkdf" +) + +const ( + PairingTokenExpiry = 5 * time.Minute + PairingAttemptLimit = 10 +) + +type PairingSession struct { + CreatedAt time.Time + Token string + Challenge []byte + Salt []byte + Attempts int +} + +type PairingManager struct { + sessions map[string]*PairingSession + mu sync.RWMutex +} + +type PairingInitiateRequest struct { + DeviceName string `json:"deviceName"` +} + +type PairingInitiateResponse struct { + PairingToken string `json:"pairingToken"` + ExpiresIn int `json:"expiresIn"` +} + +type PairingCompleteRequest struct { + PairingToken string `json:"pairingToken"` + Verifier string `json:"verifier"` + DeviceName string `json:"deviceName"` +} + +type PairingCompleteResponse struct { + DeviceID string `json:"deviceId"` + AuthToken string `json:"authToken"` + SharedSecret string `json:"sharedSecret"` // Base64 encoded +} + +var pairingManager = &PairingManager{ + sessions: make(map[string]*PairingSession), +} + +func init() { + // Start cleanup routine + go pairingManager.cleanup() +} + +func (pm *PairingManager) createSession() (*PairingSession, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + // Generate random token and challenge + token := uuid.New().String() + challenge := make([]byte, 32) + if _, err := rand.Read(challenge); err != nil { + return nil, fmt.Errorf("failed to generate challenge: %w", err) + } + + // Generate random salt (32 bytes for SHA-256) + salt := make([]byte, 32) + if _, err := rand.Read(salt); err != nil { + return nil, fmt.Errorf("failed to generate salt: %w", err) + } + + session := &PairingSession{ + Token: token, + Challenge: challenge, + Salt: salt, + CreatedAt: time.Now(), + Attempts: 0, + } + + pm.sessions[token] = session + + log.Debug().Str("token", token[:8]+"...").Msg("created pairing session") + return session, nil +} + +func (pm *PairingManager) consumeSession(token string) (*PairingSession, bool) { + pm.mu.Lock() + defer pm.mu.Unlock() + + session, exists := pm.sessions[token] + if !exists { + return nil, false + } + + // Check if expired + if time.Since(session.CreatedAt) > PairingTokenExpiry { + delete(pm.sessions, token) + return nil, false + } + + // Increment attempts + session.Attempts++ + + // Check attempt limit + if session.Attempts >= PairingAttemptLimit { + delete(pm.sessions, token) + return nil, false + } + + return session, true +} + +func (pm *PairingManager) cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + pm.mu.Lock() + now := time.Now() + for token, session := range pm.sessions { + if now.Sub(session.CreatedAt) > PairingTokenExpiry { + delete(pm.sessions, token) + } + } + pm.mu.Unlock() + } +} + +func handlePairingInitiate(_ *database.Database) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + session, err := pairingManager.createSession() + if err != nil { + log.Error().Err(err).Msg("failed to create pairing session") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + response := PairingInitiateResponse{ + PairingToken: session.Token, + ExpiresIn: int(PairingTokenExpiry.Seconds()), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("failed to encode pairing response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Info().Str("token", session.Token[:8]+"...").Msg("pairing session initiated") + } +} + +func handlePairingComplete(db *database.Database) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req PairingCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request format", http.StatusBadRequest) + return + } + + if req.PairingToken == "" || req.Verifier == "" || req.DeviceName == "" { + http.Error(w, "Missing required fields", http.StatusBadRequest) + return + } + + // Get and consume pairing session + session, exists := pairingManager.consumeSession(req.PairingToken) + if !exists { + http.Error(w, "Invalid or expired pairing token", http.StatusBadRequest) + return + } + + // Derive shared secret using HKDF (challenge + verifier) + combinedSecret := make([]byte, len(session.Challenge)+len(req.Verifier)) + copy(combinedSecret, session.Challenge) + copy(combinedSecret[len(session.Challenge):], req.Verifier) + sharedSecret := make([]byte, 32) // 256 bits for AES-256 + + // Construct context-specific info string for domain separation + info := []byte("zaparoo-pairing-v1|" + req.PairingToken + "|" + req.DeviceName) + + kdf := hkdf.New(sha256.New, combinedSecret, session.Salt, info) + if _, err := kdf.Read(sharedSecret); err != nil { + log.Error().Err(err).Msg("failed to derive shared secret") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Generate auth token + authToken := uuid.New().String() + + // Create device in database + device, err := db.UserDB.CreateDevice(req.DeviceName, authToken, sharedSecret) + if err != nil { + log.Error().Err(err).Msg("failed to create device") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Remove session from manager + pairingManager.mu.Lock() + delete(pairingManager.sessions, req.PairingToken) + pairingManager.mu.Unlock() + + response := PairingCompleteResponse{ + DeviceID: device.DeviceID, + AuthToken: authToken, + SharedSecret: hex.EncodeToString(sharedSecret), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("failed to encode pairing response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Info(). + Str("device_id", device.DeviceID). + Str("device_name", device.DeviceName). + Msg("device paired successfully") + } +} + +// QRCodeData represents the data embedded in the pairing QR code +type QRCodeData struct { + Address string `json:"address"` + Token string `json:"token"` +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 4e0dc2f4d..7e9fdd4eb 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -43,6 +43,7 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/assets" "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" "github.com/ZaparooProject/zaparoo-core/v2/pkg/helpers" "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" @@ -427,7 +428,7 @@ func buildDynamicAllowedOrigins(baseOrigins, localIPs []string, port int, custom } // broadcastNotifications consumes and broadcasts all incoming API -// notifications to all connected clients. +// notifications to all connected clients with appropriate encryption. func broadcastNotifications( st *state.State, session *melody.Melody, @@ -451,11 +452,61 @@ func broadcastNotifications( continue } - // TODO: this will not work with encryption - err = session.Broadcast(data) - if err != nil { - log.Error().Err(err).Msg("broadcasting notification") - } + // Broadcast to localhost sessions (unencrypted) + _ = session.BroadcastFilter(data, func(s *melody.Session) bool { + rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2) + clientIP := net.ParseIP(rawIP[0]) + return clientIP.IsLoopback() + }) + + // Broadcast to authenticated remote sessions (encrypted) + _ = session.BroadcastFilter(nil, func(s *melody.Session) bool { + // Check if this is a remote connection + rawIP := strings.SplitN(s.Request.RemoteAddr, ":", 2) + clientIP := net.ParseIP(rawIP[0]) + if clientIP.IsLoopback() { + return false // Skip localhost + } + + // Check if session is authenticated + device, authenticated := s.Get("device") + if !authenticated { + return false // Skip unauthenticated + } + + // Encrypt notification for this session + deviceObj, ok := device.(*database.Device) + if !ok { + log.Error().Msg("invalid device type in session") + return false + } + encrypted, iv, err := apimiddleware.EncryptPayload(data, deviceObj.SharedSecret) + if err != nil { + log.Error().Err(err).Str("device_id", deviceObj.DeviceID).Msg("failed to encrypt notification") + return false + } + + encResponse := apimiddleware.EncryptedRequest{ + Encrypted: encrypted, + IV: iv, + AuthToken: "", // Not needed for notifications + } + + encData, err := json.Marshal(encResponse) + if err != nil { + log.Error().Err(err).Str("device_id", deviceObj.DeviceID). + Msg("failed to marshal encrypted notification") + return false + } + + // Send encrypted data to this session + if err := s.Write(encData); err != nil { + log.Error().Err(err).Str("device_id", deviceObj.DeviceID). + Msg("failed to send encrypted notification") + } + + return false // Don't include in broadcast since we already sent manually + }) } } } @@ -547,13 +598,50 @@ func handleWSMessage( rawIP := strings.SplitN(session.Request.RemoteAddr, ":", 2) clientIP := net.ParseIP(rawIP[0]) + isLocal := clientIP.IsLoopback() + + // Handle authentication for remote connections + if !isLocal { + device, authenticated := session.Get("device") + if !authenticated { + // First message must be authentication + err := handleWSAuthentication(session, msg, db) + if err != nil { + log.Error().Err(err).Msg("WebSocket authentication failed") + _ = session.Close() + } + return + } + + // Decrypt message for authenticated remote connection + deviceObj, ok := device.(*database.Device) + if !ok { + log.Error().Msg("invalid device type in session") + err := sendWSError(session, uuid.Nil, JSONRPCErrorInternalError) + if err != nil { + log.Error().Err(err).Msg("failed to send WebSocket error") + } + return + } + decryptedMsg, err := handleWSDecryption(session, msg, deviceObj, db) + if err != nil { + log.Error().Err(err).Msg("WebSocket decryption failed") + err := sendWSError(session, uuid.Nil, JSONRPCErrorInvalidRequest) + if err != nil { + log.Error().Err(err).Msg("error sending decryption error response") + } + return + } + msg = decryptedMsg + } + env := requests.RequestEnv{ Platform: platform, Config: cfg, State: st, Database: db, TokenQueue: inTokenQueue, - IsLocal: clientIP.IsLoopback(), + IsLocal: isLocal, } id, resp, rpcError := processRequestObject(methodMap, env, msg) @@ -563,6 +651,23 @@ func handleWSMessage( log.Error().Err(err).Msg("error sending error response") } } else { + // Encrypt response for remote authenticated connections + if !isLocal { + if device, authenticated := session.Get("device"); authenticated { + deviceObj, ok := device.(*database.Device) + if !ok { + log.Error().Msg("invalid device type in session") + return + } + err := sendWSResponseEncrypted(session, id, resp, deviceObj) + if err != nil { + log.Error().Err(err).Msg("error sending encrypted response") + } + return + } + } + + // Send unencrypted response for localhost err := sendWSResponse(session, id, resp) if err != nil { log.Error().Err(err).Msg("error sending response") @@ -571,6 +676,144 @@ func handleWSMessage( } } +type WSAuthMessage struct { + AuthToken string `json:"authToken"` +} + +func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Database) error { + var authMsg WSAuthMessage + if err := json.Unmarshal(msg, &authMsg); err != nil { + return fmt.Errorf("invalid auth message format: %w", err) + } + + if authMsg.AuthToken == "" { + return errors.New("missing auth token") + } + + // Validate auth token and get device + device, err := db.UserDB.GetDeviceByAuthToken(authMsg.AuthToken) + if err != nil { + return fmt.Errorf("invalid auth token: %w", err) + } + + // Store device in session + session.Set("device", device) + + // Send authentication success response + authResponse := map[string]any{ + "authenticated": true, + "device_id": device.DeviceID, + } + + responseData, _ := json.Marshal(authResponse) + err = session.Write(responseData) + if err != nil { + return fmt.Errorf("failed to send auth response: %w", err) + } + + log.Debug().Str("device_id", device.DeviceID).Msg("WebSocket authenticated") + return nil +} + +func handleWSDecryption(_ *melody.Session, msg []byte, device *database.Device, db *database.Database) ([]byte, error) { + // Parse encrypted message + var encMsg apimiddleware.EncryptedRequest + if err := json.Unmarshal(msg, &encMsg); err != nil { + return nil, fmt.Errorf("invalid encrypted message format: %w", err) + } + + // Decrypt payload (we can reuse the middleware function by creating a temporary import) + decryptedPayload, err := apimiddleware.DecryptPayload(encMsg.Encrypted, encMsg.IV, device.SharedSecret) + if err != nil { + return nil, fmt.Errorf("decryption failed: %w", err) + } + + // Parse and validate sequence/nonce + var payload apimiddleware.DecryptedPayload + if unmarshalErr := json.Unmarshal(decryptedPayload, &payload); unmarshalErr != nil { + return nil, fmt.Errorf("invalid decrypted payload: %w", unmarshalErr) + } + + // CRITICAL SECTION: Acquire device lock to prevent race conditions + // between validation and database update + unlockDevice := apimiddleware.LockDevice(device.DeviceID) + defer unlockDevice() + + // Re-fetch device state under lock to get latest sequence/nonce state + freshDevice, err := db.UserDB.GetDeviceByID(device.DeviceID) + if err != nil { + return nil, fmt.Errorf("failed to re-fetch device under lock: %w", err) + } + + // Validate sequence and nonce with fresh device state + if !apimiddleware.ValidateSequenceAndNonce(freshDevice, payload.Seq, payload.Nonce) { + return nil, errors.New("invalid sequence or replay detected") + } + + // Update device state under lock + userDB, ok := db.UserDB.(*userdb.UserDB) + if !ok { + return nil, errors.New("failed to cast UserDB to concrete type") + } + if updateErr := apimiddleware.UpdateDeviceState(userDB, freshDevice, payload.Seq, payload.Nonce); updateErr != nil { + return nil, fmt.Errorf("failed to update device state: %w", updateErr) + } + + // Return JSON-RPC payload without sequence/nonce + originalPayload := map[string]any{ + "jsonrpc": payload.JSONRPC, + "method": payload.Method, + "id": payload.ID, + } + if payload.Params != nil { + originalPayload["params"] = payload.Params + } + + result, err := json.Marshal(originalPayload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + return result, nil +} + +func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, device *database.Device) error { + // Create response object + resp := models.ResponseObject{ + JSONRPC: "2.0", + ID: id, + Result: result, + } + + // Marshal to JSON + data, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("error marshalling response: %w", err) + } + + // Encrypt the response + encrypted, iv, err := apimiddleware.EncryptPayload(data, device.SharedSecret) + if err != nil { + return fmt.Errorf("failed to encrypt response: %w", err) + } + + // Send encrypted response + encResponse := apimiddleware.EncryptedRequest{ + Encrypted: encrypted, + IV: iv, + AuthToken: "", // Not needed for responses + } + + encData, err := json.Marshal(encResponse) + if err != nil { + return fmt.Errorf("failed to marshal encrypted response: %w", err) + } + + if err := session.Write(encData); err != nil { + return fmt.Errorf("failed to send encrypted response: %w", err) + } + return nil +} + func handlePostRequest( methodMap *MethodMap, platform platforms.Platform, @@ -700,29 +943,43 @@ func Start( } go broadcastNotifications(st, session, notifications) - r.Get("/api", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: latest") - } + // Pairing endpoints (no authentication required) + r.Post("/api/pair/initiate", handlePairingInitiate(db)) + r.Post("/api/pair/complete", handlePairingComplete(db)) + + // Protected API routes with authentication middleware + r.Route("/api", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: latest") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) }) - r.Post("/api", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) - r.Get("/api/v0", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: v0") - } + r.Route("/api/v0", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: v0") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) }) - r.Post("/api/v0", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) - r.Get("/api/v0.1", func(w http.ResponseWriter, r *http.Request) { - err := session.HandleRequest(w, r) - if err != nil { - log.Error().Err(err).Msg("handling websocket request: v0.1") - } + r.Route("/api/v0.1", func(r chi.Router) { + r.Use(apimiddleware.AuthMiddleware(db)) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + err := session.HandleRequest(w, r) + if err != nil { + log.Error().Err(err).Msg("handling websocket request: v0.1") + } + }) + r.Post("/", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) }) - r.Post("/api/v0.1", handlePostRequest(methodMap, platform, cfg, st, inTokenQueue, db)) session.HandleMessage(apimiddleware.WebSocketRateLimitHandler( rateLimiter, diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 7fcdc2b13..60bc4f0a7 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -40,16 +40,19 @@ import ( ) type Flags struct { - Write *string - Read *bool - Run *string - Launch *string - API *string - Version *bool - Config *bool - ShowLoader *string - ShowPicker *string - Reload *bool + Write *string + Read *bool + Run *string + Launch *string + API *string + Version *bool + Config *bool + ShowLoader *string + ShowPairingCode *bool + ListDevices *bool + RevokeDevice *string + ShowPicker *string + Reload *bool } // SetupFlags defines all common CLI flags between platforms. @@ -95,6 +98,21 @@ func SetupFlags() *Flags { false, "reload config and mappings from disk", ), + ShowPairingCode: flag.Bool( + "show-pairing-code", + false, + "display QR code for device pairing", + ), + ListDevices: flag.Bool( + "list-devices", + false, + "list all paired devices", + ), + RevokeDevice: flag.String( + "revoke-device", + "", + "revoke access for device by ID", + ), } } @@ -139,8 +157,17 @@ func runFlag(cfg *config.Instance, value string) { // Post actions all remaining common flags that require the environment to be // set up. Logging is allowed. -func (f *Flags) Post(cfg *config.Instance, _ platforms.Platform) { +func (f *Flags) Post(cfg *config.Instance, pl platforms.Platform) { switch { + case *f.ShowPairingCode: + handleShowPairingCode(cfg, pl) + os.Exit(0) + case *f.ListDevices: + handleListDevices(cfg, pl) + os.Exit(0) + case isFlagPassed("revoke-device"): + handleRevokeDevice(cfg, pl, *f.RevokeDevice) + os.Exit(0) case isFlagPassed("write"): if *f.Write == "" { _, _ = fmt.Fprint(os.Stderr, "Error: write flag requires a value\n") diff --git a/pkg/cli/devices.go b/pkg/cli/devices.go new file mode 100644 index 000000000..ea725435c --- /dev/null +++ b/pkg/cli/devices.go @@ -0,0 +1,213 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package cli + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/api" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" + "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" + "github.com/rs/zerolog/log" +) + +func generateQRCode(data string) { + // Simple ASCII QR code placeholder - in a real implementation, you'd use a QR code library + if _, err := fmt.Print("\n=== PAIRING QR CODE ===\n"); err != nil { + log.Error().Err(err).Msg("failed to print header") + } + if _, err := fmt.Print("Scan this with your mobile app:\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print instruction") + } + + // For now, just display the data - a real implementation would generate ASCII QR code + if _, err := fmt.Printf("Data: %s\n\n", data); err != nil { + log.Error().Err(err).Msg("failed to print data") + } + + // You could use a library like github.com/skip2/go-qrcode for actual QR generation + note := "Note: QR code display not yet implemented - use manual pairing with the data above\n" + if _, err := fmt.Print(note); err != nil { + log.Error().Err(err).Msg("failed to print note") + } + if _, err := fmt.Print("======================\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print footer") + } +} + +func handleShowPairingCode(cfg *config.Instance, pl platforms.Platform) { + // Open user database to check pairing sessions + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + // Create HTTP client to call our own pairing API + client := &http.Client{Timeout: 10 * time.Second} + + // Call pairing initiate endpoint + apiURL := fmt.Sprintf("http://localhost:%d/api/pair/initiate", cfg.APIPort()) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, apiURL, strings.NewReader("{}")) + if err != nil { + log.Error().Err(err).Msg("failed to create pairing request") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + log.Error().Err(err).Msg("failed to initiate pairing") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error initiating pairing: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + defer func() { _ = resp.Body.Close() }() + + var pairingResp api.PairingInitiateResponse + if err := json.NewDecoder(resp.Body).Decode(&pairingResp); err != nil { + log.Error().Err(err).Msg("failed to decode pairing response") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error decoding response: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + // Generate QR code data + serverAddr := fmt.Sprintf("localhost:%d", cfg.APIPort()) + qrData := api.QRCodeData{ + Address: serverAddr, + Token: pairingResp.PairingToken, + } + + jsonData, _ := json.Marshal(qrData) + generateQRCode(string(jsonData)) + + if _, err := fmt.Printf("Pairing token: %s\n", pairingResp.PairingToken); err != nil { + log.Error().Err(err).Msg("failed to print pairing token") + } + if _, err := fmt.Printf("Expires in: %d seconds\n", pairingResp.ExpiresIn); err != nil { + log.Error().Err(err).Msg("failed to print expiration time") + } + if _, err := fmt.Print("\nWaiting for device to pair... (Ctrl+C to cancel)\n"); err != nil { + log.Error().Err(err).Msg("failed to print waiting message") + } + + // Wait for user to cancel + select {} +} + +func handleListDevices(_ *config.Instance, pl platforms.Platform) { + // Open user database to list devices + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + devices, err := userDB.GetAllDevices() + if err != nil { + log.Error().Err(err).Msg("failed to get devices") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error getting devices: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + if len(devices) == 0 { + if _, err := fmt.Println("No paired devices found."); err != nil { + log.Error().Err(err).Msg("failed to print message") + } + return + } + + if _, err := fmt.Print("Paired devices:\n\n"); err != nil { + log.Error().Err(err).Msg("failed to print header") + } + if _, err := fmt.Printf("%-36s %-20s %-10s %s\n", "Device ID", "Name", "Sequence", "Last Seen"); err != nil { + log.Error().Err(err).Msg("failed to print column headers") + } + if _, err := fmt.Printf("%s\n", strings.Repeat("-", 80)); err != nil { + log.Error().Err(err).Msg("failed to print separator") + } + + for i := range devices { + device := &devices[i] + if _, err := fmt.Printf("%-36s %-20s %-10d %s\n", + device.DeviceID, + device.DeviceName, + device.CurrentSeq, + device.LastSeen.Format("2006-01-02 15:04:05"), + ); err != nil { + log.Error().Err(err).Msg("failed to print device info") + } + } +} + +func handleRevokeDevice(_ *config.Instance, pl platforms.Platform, deviceID string) { + if deviceID == "" { + if _, err := fmt.Fprint(os.Stderr, "Error: device ID is required\n"); err != nil { + log.Error().Err(err).Msg("failed to write error message") + } + os.Exit(1) + } + + // Open user database to revoke device + userDB, err := userdb.OpenUserDB(context.Background(), pl) + if err != nil { + log.Error().Err(err).Msg("failed to open user database") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error opening user database: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + os.Exit(1) + } + defer func() { _ = userDB.Close() }() + + err = userDB.DeleteDevice(deviceID) + if err != nil { + log.Error().Err(err).Msg("failed to delete device") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error deleting device: %v\n", err); writeErr != nil { + log.Error().Err(writeErr).Msg("failed to write error message") + } + return + } + + if _, err := fmt.Printf("Device %s has been revoked successfully.\n", deviceID); err != nil { + log.Error().Err(err).Msg("failed to print success message") + } +} diff --git a/pkg/database/database.go b/pkg/database/database.go index ec7961a9c..e42c1a463 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -69,6 +69,18 @@ type System struct { DBID int64 } +type Device struct { + CreatedAt time.Time `json:"createdAt"` + LastSeen time.Time `json:"lastSeen"` + DeviceID string `json:"deviceId"` + DeviceName string `json:"deviceName"` + AuthTokenHash string `json:"-"` + SharedSecret []byte `json:"-"` + SeqWindow []byte `json:"-"` + NonceCache []string `json:"-"` + CurrentSeq uint64 `json:"currentSeq"` +} + type MediaTitle struct { Slug string Name string @@ -154,6 +166,12 @@ type UserDBI interface { GetZapLinkHost(host string) (bool, bool, error) UpdateZapLinkCache(url string, zapscript string) error GetZapLinkCache(url string) (string, error) + CreateDevice(deviceName, authToken string, sharedSecret []byte) (*Device, error) + GetDeviceByAuthToken(authToken string) (*Device, error) + GetDeviceByID(deviceID string) (*Device, error) + UpdateDeviceSequence(deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string) error + GetAllDevices() ([]Device, error) + DeleteDevice(deviceID string) error } type MediaDBI interface { diff --git a/pkg/database/userdb/devices.go b/pkg/database/userdb/devices.go new file mode 100644 index 000000000..187e17768 --- /dev/null +++ b/pkg/database/userdb/devices.go @@ -0,0 +1,280 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package userdb + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" + "github.com/google/uuid" +) + +func (db *UserDB) CreateDevice(deviceName, authToken string, sharedSecret []byte) (*database.Device, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + deviceID := uuid.New().String() + authTokenHash := hashAuthToken(authToken) + now := time.Now().Unix() + + device := &database.Device{ + DeviceID: deviceID, + DeviceName: deviceName, + AuthTokenHash: authTokenHash, + SharedSecret: sharedSecret, + CurrentSeq: 0, + SeqWindow: make([]byte, 8), // 64-bit window + NonceCache: make([]string, 0), + CreatedAt: time.Unix(now, 0), + LastSeen: time.Unix(now, 0), + } + + nonceCacheJSON, err := json.Marshal(device.NonceCache) + if err != nil { + return nil, fmt.Errorf("failed to marshal nonce cache: %w", err) + } + + query := ` + INSERT INTO devices (device_id, device_name, auth_token_hash, shared_secret, + current_seq, seq_window, nonce_cache, created_at, last_seen) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + + _, err = db.sql.ExecContext(context.Background(), query, + device.DeviceID, + device.DeviceName, + device.AuthTokenHash, + device.SharedSecret, + device.CurrentSeq, + device.SeqWindow, + string(nonceCacheJSON), + now, + now, + ) + if err != nil { + return nil, fmt.Errorf("failed to create device: %w", err) + } + + return device, nil +} + +func (db *UserDB) GetDeviceByAuthToken(authToken string) (*database.Device, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + authTokenHash := hashAuthToken(authToken) + + query := ` + SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM devices + WHERE auth_token_hash = ? + ` + + var device database.Device + var nonceCacheJSON string + var createdAt, lastSeen int64 + + err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan( + &device.DeviceID, + &device.DeviceName, + &device.AuthTokenHash, + &device.SharedSecret, + &device.CurrentSeq, + &device.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if err != nil { + return nil, fmt.Errorf("device not found: %w", err) + } + + device.CreatedAt = time.Unix(createdAt, 0) + device.LastSeen = time.Unix(lastSeen, 0) + + err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + if err != nil { + device.NonceCache = make([]string, 0) // Fallback to empty cache + } + + return &device, nil +} + +func (db *UserDB) GetDeviceByID(deviceID string) (*database.Device, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + query := ` + SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM devices + WHERE device_id = ? + ` + + var device database.Device + var nonceCacheJSON string + var createdAt, lastSeen int64 + + err := db.sql.QueryRowContext(context.Background(), query, deviceID).Scan( + &device.DeviceID, + &device.DeviceName, + &device.AuthTokenHash, + &device.SharedSecret, + &device.CurrentSeq, + &device.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if err != nil { + return nil, fmt.Errorf("device not found: %w", err) + } + + device.CreatedAt = time.Unix(createdAt, 0) + device.LastSeen = time.Unix(lastSeen, 0) + + err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + if err != nil { + device.NonceCache = make([]string, 0) // Fallback to empty cache + } + + return &device, nil +} + +func (db *UserDB) UpdateDeviceSequence(deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string) error { + if db.sql == nil { + return ErrNullSQL + } + + nonceCacheJSON, err := json.Marshal(nonceCache) + if err != nil { + return fmt.Errorf("failed to marshal nonce cache: %w", err) + } + + query := ` + UPDATE devices + SET current_seq = ?, seq_window = ?, nonce_cache = ?, last_seen = ? + WHERE device_id = ? + ` + + _, err = db.sql.ExecContext( + context.Background(), query, newSeq, seqWindow, string(nonceCacheJSON), time.Now().Unix(), deviceID, + ) + if err != nil { + return fmt.Errorf("failed to update device sequence: %w", err) + } + + return nil +} + +func (db *UserDB) GetAllDevices() ([]database.Device, error) { + if db.sql == nil { + return nil, ErrNullSQL + } + + query := ` + SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM devices + ORDER BY last_seen DESC + ` + + rows, err := db.sql.QueryContext(context.Background(), query) + if err != nil { + return nil, fmt.Errorf("failed to query devices: %w", err) + } + defer func() { _ = rows.Close() }() + + devices := make([]database.Device, 0) + for rows.Next() { + var device database.Device + var nonceCacheJSON string + var createdAt, lastSeen int64 + + scanErr := rows.Scan( + &device.DeviceID, + &device.DeviceName, + &device.AuthTokenHash, + &device.SharedSecret, + &device.CurrentSeq, + &device.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) + if scanErr != nil { + return nil, fmt.Errorf("failed to scan device row: %w", scanErr) + } + + device.CreatedAt = time.Unix(createdAt, 0) + device.LastSeen = time.Unix(lastSeen, 0) + + err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + if err != nil { + device.NonceCache = make([]string, 0) // Fallback to empty cache + } + + devices = append(devices, device) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error reading device rows: %w", err) + } + + return devices, nil +} + +func (db *UserDB) DeleteDevice(deviceID string) error { + if db.sql == nil { + return ErrNullSQL + } + + query := `DELETE FROM devices WHERE device_id = ?` + result, err := db.sql.ExecContext(context.Background(), query, deviceID) + if err != nil { + return fmt.Errorf("failed to delete device: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return errors.New("device not found") + } + + return nil +} + +func hashAuthToken(authToken string) string { + hash := sha256.Sum256([]byte(authToken)) + return hex.EncodeToString(hash[:]) +} diff --git a/pkg/database/userdb/migrations/20250912231204_devices_auth.sql b/pkg/database/userdb/migrations/20250912231204_devices_auth.sql new file mode 100644 index 000000000..3bddc26ad --- /dev/null +++ b/pkg/database/userdb/migrations/20250912231204_devices_auth.sql @@ -0,0 +1,22 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE devices ( + device_id TEXT PRIMARY KEY, + device_name TEXT NOT NULL, + auth_token_hash TEXT NOT NULL UNIQUE, + shared_secret BLOB NOT NULL, + current_seq INTEGER DEFAULT 0, + seq_window BLOB, + nonce_cache TEXT DEFAULT '[]', + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + last_seen INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) +); + +CREATE INDEX idx_devices_auth_token ON devices(auth_token_hash); +CREATE INDEX idx_devices_last_seen ON devices(last_seen); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE devices; +-- +goose StatementEnd \ No newline at end of file diff --git a/pkg/testing/helpers/db_mocks.go b/pkg/testing/helpers/db_mocks.go index b95b4e178..607e07bdd 100644 --- a/pkg/testing/helpers/db_mocks.go +++ b/pkg/testing/helpers/db_mocks.go @@ -57,6 +57,11 @@ import ( "github.com/stretchr/testify/mock" ) +// Sentinel errors for mock functions +var ( + ErrMockNotConfigured = errors.New("mock not configured for this call") +) + // MockUserDBI is a mock implementation of the UserDBI interface using testify/mock type MockUserDBI struct { mock.Mock @@ -239,6 +244,81 @@ func (m *MockUserDBI) GetZapLinkCache(url string) (string, error) { return args.String(0), args.Error(1) } +// Device authentication methods +func (m *MockUserDBI) CreateDevice(deviceName, authToken string, sharedSecret []byte) (*database.Device, error) { + args := m.Called(deviceName, authToken, sharedSecret) + if device, ok := args.Get(0).(*database.Device); ok { + if err := args.Error(1); err != nil { + return device, fmt.Errorf("mock UserDBI create device failed: %w", err) + } + return device, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI create device failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) GetDeviceByAuthToken(authToken string) (*database.Device, error) { + args := m.Called(authToken) + if device, ok := args.Get(0).(*database.Device); ok { + if err := args.Error(1); err != nil { + return device, fmt.Errorf("mock UserDBI get device by auth token failed: %w", err) + } + return device, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get device by auth token failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) GetDeviceByID(deviceID string) (*database.Device, error) { + args := m.Called(deviceID) + if device, ok := args.Get(0).(*database.Device); ok { + if err := args.Error(1); err != nil { + return device, fmt.Errorf("mock UserDBI get device by ID failed: %w", err) + } + return device, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get device by ID failed: %w", err) + } + return nil, ErrMockNotConfigured +} + +func (m *MockUserDBI) UpdateDeviceSequence( + deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string, +) error { + args := m.Called(deviceID, newSeq, seqWindow, nonceCache) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI update device sequence failed: %w", err) + } + return nil +} + +func (m *MockUserDBI) GetAllDevices() ([]database.Device, error) { + args := m.Called() + if devices, ok := args.Get(0).([]database.Device); ok { + if err := args.Error(1); err != nil { + return devices, fmt.Errorf("mock UserDBI get all devices failed: %w", err) + } + return devices, nil + } + if err := args.Error(1); err != nil { + return nil, fmt.Errorf("mock UserDBI get all devices failed: %w", err) + } + return nil, nil +} + +func (m *MockUserDBI) DeleteDevice(deviceID string) error { + args := m.Called(deviceID) + if err := args.Error(0); err != nil { + return fmt.Errorf("mock UserDBI delete device failed: %w", err) + } + return nil +} + // MockMediaDBI is a mock implementation of the MediaDBI interface using testify/mock type MockMediaDBI struct { mock.Mock From c1842d4de852afb9937e9e4263e83395cda38f83 Mon Sep 17 00:00:00 2001 From: Callan Barrett Date: Sat, 13 Sep 2025 08:43:37 +0800 Subject: [PATCH 2/5] rename device to client --- pkg/api/middleware/auth.go | 200 +++++++++--------- pkg/api/middleware/auth_test.go | 120 +++++------ pkg/api/pairing.go | 24 +-- pkg/api/server.go | 74 +++---- pkg/cli/cli.go | 26 +-- pkg/cli/{devices.go => clients.go} | 50 ++--- pkg/database/database.go | 18 +- .../userdb/{devices.go => clients.go} | 158 +++++++------- ...th.sql => 20250912231204_clients_auth.sql} | 12 +- pkg/testing/helpers/db_mocks.go | 60 +++--- 10 files changed, 371 insertions(+), 371 deletions(-) rename pkg/cli/{devices.go => clients.go} (83%) rename pkg/database/userdb/{devices.go => clients.go} (55%) rename pkg/database/userdb/migrations/{20250912231204_devices_auth.sql => 20250912231204_clients_auth.sql} (64%) diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 50150c858..817b144b0 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -41,27 +41,27 @@ import ( ) const ( - SequenceWindow = 64 // Size of sliding window for sequence numbers - NonceCacheSize = 100 // Maximum number of cached nonces + SequenceWindow = 64 // Size of sliding window for sequence numbers + NonceCacheSize = 100 // Maximum number of cached nonces MutexCleanupInterval = 10 * time.Minute // Cleanup unused mutexes every 10 minutes - MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes + MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes ) -type deviceKey string +type clientKey string -// DeviceMutexManager handles per-device locking to prevent race conditions +// ClientMutexManager handles per-client locking to prevent race conditions // in authentication state updates -type DeviceMutexManager struct { - mutexes sync.Map // map[string]*deviceMutex +type ClientMutexManager struct { + mutexes sync.Map // map[string]*clientMutex } -type deviceMutex struct { - mu sync.Mutex +type clientMutex struct { lastUsed time.Time - deviceID string + clientID string + mu sync.Mutex } -var globalDeviceMutexManager = &DeviceMutexManager{} +var globalClientMutexManager = &ClientMutexManager{} type EncryptedRequest struct { Encrypted string `json:"encrypted"` @@ -118,8 +118,8 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Validate auth token and get device - device, err := db.UserDB.GetDeviceByAuthToken(encReq.AuthToken) + // Validate auth token and get client + client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) if err != nil { tokenStr := "empty" if len(encReq.AuthToken) >= 8 { @@ -133,7 +133,7 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { } // Decrypt payload - decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, device.SharedSecret) + decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, client.SharedSecret) if err != nil { log.Error().Err(err).Msg("failed to decrypt payload") http.Error(w, "Decryption failed", http.StatusBadRequest) @@ -148,23 +148,23 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // CRITICAL SECTION: Acquire device lock to prevent race conditions + // Acquire client lock to prevent race conditions // between validation and database update - unlockDevice := LockDevice(device.DeviceID) - defer unlockDevice() + unlockClient := LockClient(client.ClientID) + defer unlockClient() - // Re-fetch device state under lock to get latest sequence/nonce state - freshDevice, err := db.UserDB.GetDeviceByAuthToken(encReq.AuthToken) + // Re-fetch client state under lock to get latest sequence/nonce state + freshClient, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) if err != nil { - log.Error().Err(err).Msg("failed to re-fetch device under lock") + log.Error().Err(err).Msg("failed to re-fetch client under lock") http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Validate sequence number and nonce with fresh device state - if !ValidateSequenceAndNonce(freshDevice, payload.Seq, payload.Nonce) { + // Validate sequence number and nonce with fresh client state + if !ValidateSequenceAndNonce(freshClient, payload.Seq, payload.Nonce) { log.Warn(). - Str("device_id", freshDevice.DeviceID). + Str("client_id", freshClient.ClientID). Uint64("seq", payload.Seq). Str("nonce", payload.Nonce). Msg("invalid sequence or replay attack detected") @@ -172,21 +172,21 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Update device state (sequence and nonce cache) - updatedDevice := *freshDevice - updateDeviceSequenceAndNonce(&updatedDevice, payload.Seq, payload.Nonce) + // Update client state (sequence and nonce cache) + updatedClient := *freshClient + updateClientSequenceAndNonce(&updatedClient, payload.Seq, payload.Nonce) // Save to database (still under lock) - if updateErr := db.UserDB.UpdateDeviceSequence( - updatedDevice.DeviceID, updatedDevice.CurrentSeq, updatedDevice.SeqWindow, updatedDevice.NonceCache, + if updateErr := db.UserDB.UpdateClientSequence( + updatedClient.ClientID, updatedClient.CurrentSeq, updatedClient.SeqWindow, updatedClient.NonceCache, ); updateErr != nil { - log.Error().Err(updateErr).Msg("failed to update device state") + log.Error().Err(updateErr).Msg("failed to update client state") http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Update the device pointer for context (use fresh device with updates) - device = &updatedDevice + // Update the client pointer for context (use fresh client with updates) + client = &updatedClient // Replace request body with decrypted JSON-RPC payload originalPayload := map[string]any{ @@ -209,12 +209,12 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { r.Body = io.NopCloser(bytes.NewReader(newBody)) r.ContentLength = int64(len(newBody)) - // Store device in context for potential use by handlers - ctx := context.WithValue(r.Context(), deviceKey("device"), device) + // Store client in context for potential use by handlers + ctx := context.WithValue(r.Context(), clientKey("client"), client) r = r.WithContext(ctx) log.Debug(). - Str("device_id", device.DeviceID). + Str("client_id", client.ClientID). Str("method", payload.Method). Uint64("seq", payload.Seq). Msg("authenticated request processed") @@ -284,16 +284,16 @@ func EncryptPayload(data, key []byte) (encrypted, iv string, err error) { base64.StdEncoding.EncodeToString(ivBytes), nil } -func ValidateSequenceAndNonce(device *database.Device, seq uint64, nonce string) bool { +func ValidateSequenceAndNonce(client *database.Client, seq uint64, nonce string) bool { // Check if nonce was recently used (replay protection) - if slices.Contains(device.NonceCache, nonce) { + if slices.Contains(client.NonceCache, nonce) { return false } // Validate sequence number with sliding window - if seq <= device.CurrentSeq { + if seq <= client.CurrentSeq { // Check if sequence is within acceptable window - diff := device.CurrentSeq - seq + diff := client.CurrentSeq - seq if diff >= SequenceWindow { return false // Too old } @@ -303,8 +303,8 @@ func ValidateSequenceAndNonce(device *database.Device, seq uint64, nonce string) bytePos := windowPos / 8 bitPos := windowPos % 8 - if bytePos < uint64(len(device.SeqWindow)) { - if (device.SeqWindow[bytePos] & (1 << bitPos)) != 0 { + if bytePos < uint64(len(client.SeqWindow)) { + if (client.SeqWindow[bytePos] & (1 << bitPos)) != 0 { return false // Already processed } } @@ -313,159 +313,159 @@ func ValidateSequenceAndNonce(device *database.Device, seq uint64, nonce string) return true } -func updateDeviceSequenceAndNonce(device *database.Device, seq uint64, nonce string) { +func updateClientSequenceAndNonce(client *database.Client, seq uint64, nonce string) { // Update nonce cache (keep last NonceCacheSize nonces) - device.NonceCache = append(device.NonceCache, nonce) - if len(device.NonceCache) > NonceCacheSize { - device.NonceCache = device.NonceCache[1:] // Remove oldest + client.NonceCache = append(client.NonceCache, nonce) + if len(client.NonceCache) > NonceCacheSize { + client.NonceCache = client.NonceCache[1:] // Remove oldest } // Update sequence window - if seq > device.CurrentSeq { + if seq > client.CurrentSeq { // New highest sequence - shift window - shift := seq - device.CurrentSeq + shift := seq - client.CurrentSeq if shift >= SequenceWindow { // Clear entire window - device.SeqWindow = make([]byte, 8) + client.SeqWindow = make([]byte, 8) } else { // Shift window right for range shift { - shiftWindowRight(device.SeqWindow) + shiftWindowRight(client.SeqWindow) } } - device.CurrentSeq = seq + client.CurrentSeq = seq // Mark current sequence as processed (position 0 in window) - device.SeqWindow[0] |= 1 + client.SeqWindow[0] |= 1 } else { // Mark this sequence as processed in the window - diff := device.CurrentSeq - seq + diff := client.CurrentSeq - seq windowPos := diff % SequenceWindow bytePos := windowPos / 8 bitPos := windowPos % 8 - if bytePos < uint64(len(device.SeqWindow)) { - device.SeqWindow[bytePos] |= (1 << bitPos) + if bytePos < uint64(len(client.SeqWindow)) { + client.SeqWindow[bytePos] |= (1 << bitPos) } } } -func UpdateDeviceState(userDB *userdb.UserDB, device *database.Device, seq uint64, nonce string) error { +func UpdateClientState(userDB *userdb.UserDB, client *database.Client, seq uint64, nonce string) error { // Update nonce cache (keep last NonceCacheSize nonces) - device.NonceCache = append(device.NonceCache, nonce) - if len(device.NonceCache) > NonceCacheSize { - device.NonceCache = device.NonceCache[1:] // Remove oldest + client.NonceCache = append(client.NonceCache, nonce) + if len(client.NonceCache) > NonceCacheSize { + client.NonceCache = client.NonceCache[1:] // Remove oldest } // Update sequence window - if seq > device.CurrentSeq { + if seq > client.CurrentSeq { // New highest sequence - shift window - shift := seq - device.CurrentSeq + shift := seq - client.CurrentSeq if shift >= SequenceWindow { // Clear entire window - device.SeqWindow = make([]byte, 8) + client.SeqWindow = make([]byte, 8) } else { // Shift window right for range shift { - shiftWindowRight(device.SeqWindow) + shiftWindowRight(client.SeqWindow) } } - device.CurrentSeq = seq + client.CurrentSeq = seq // Mark current sequence as processed (position 0 in window) - device.SeqWindow[0] |= 1 + client.SeqWindow[0] |= 1 } else { // Mark this sequence as processed in the window - diff := device.CurrentSeq - seq + diff := client.CurrentSeq - seq windowPos := diff % SequenceWindow bytePos := windowPos / 8 bitPos := windowPos % 8 - if bytePos < uint64(len(device.SeqWindow)) { - device.SeqWindow[bytePos] |= (1 << bitPos) + if bytePos < uint64(len(client.SeqWindow)) { + client.SeqWindow[bytePos] |= (1 << bitPos) } } // Update database - if err := userDB.UpdateDeviceSequence( - device.DeviceID, device.CurrentSeq, device.SeqWindow, device.NonceCache, + if err := userDB.UpdateClientSequence( + client.ClientID, client.CurrentSeq, client.SeqWindow, client.NonceCache, ); err != nil { - return fmt.Errorf("failed to update device sequence: %w", err) + return fmt.Errorf("failed to update client sequence: %w", err) } return nil } -// getDeviceMutex retrieves or creates a mutex for the given device ID -func (dm *DeviceMutexManager) getDeviceMutex(deviceID string) *deviceMutex { +// getClientMutex retrieves or creates a mutex for the given client ID +func (cm *ClientMutexManager) getClientMutex(clientID string) *clientMutex { // Try to load existing mutex - if value, exists := dm.mutexes.Load(deviceID); exists { - mutex := value.(*deviceMutex) + if value, exists := cm.mutexes.Load(clientID); exists { + mutex := value.(*clientMutex) mutex.lastUsed = time.Now() return mutex } // Create new mutex - newMutex := &deviceMutex{ + newMutex := &clientMutex{ lastUsed: time.Now(), - deviceID: deviceID, + clientID: clientID, } // Store and return the mutex (LoadOrStore handles race conditions) - actual, _ := dm.mutexes.LoadOrStore(deviceID, newMutex) - actualMutex := actual.(*deviceMutex) + actual, _ := cm.mutexes.LoadOrStore(clientID, newMutex) + actualMutex := actual.(*clientMutex) actualMutex.lastUsed = time.Now() return actualMutex } -// lockDevice acquires a lock for the specified device, preventing race conditions +// lockClient acquires a lock for the specified client, preventing race conditions // in authentication state updates. The returned function must be called to release the lock. -func (dm *DeviceMutexManager) lockDevice(deviceID string) func() { - mutex := dm.getDeviceMutex(deviceID) +func (cm *ClientMutexManager) lockClient(clientID string) func() { + mutex := cm.getClientMutex(clientID) mutex.mu.Lock() - + return func() { mutex.mu.Unlock() } } // cleanup removes unused mutexes to prevent memory leaks -func (dm *DeviceMutexManager) cleanup() { +func (cm *ClientMutexManager) cleanup() { now := time.Now() - dm.mutexes.Range(func(key, value interface{}) bool { - mutex := value.(*deviceMutex) + cm.mutexes.Range(func(key, value interface{}) bool { + mutex := value.(*clientMutex) if now.Sub(mutex.lastUsed) > MutexMaxIdle { - dm.mutexes.Delete(key) - log.Debug().Str("device_id", mutex.deviceID).Msg("cleaned up unused device mutex") + cm.mutexes.Delete(key) + log.Debug().Str("client_id", mutex.clientID).Msg("cleaned up unused client mutex") } return true }) } // startCleanupRoutine starts a background goroutine to periodically clean up unused mutexes -func (dm *DeviceMutexManager) startCleanupRoutine() { +func (cm *ClientMutexManager) startCleanupRoutine() { go func() { ticker := time.NewTicker(MutexCleanupInterval) defer ticker.Stop() - + for range ticker.C { - dm.cleanup() + cm.cleanup() } }() } -// GetDeviceMutex is a convenience function to get a mutex for a device -func GetDeviceMutex(deviceID string) *deviceMutex { - return globalDeviceMutexManager.getDeviceMutex(deviceID) +// GetClientMutex is a convenience function to get a mutex for a client +func GetClientMutex(clientID string) *clientMutex { + return globalClientMutexManager.getClientMutex(clientID) } -// LockDevice is a convenience function to lock a device -func LockDevice(deviceID string) func() { - return globalDeviceMutexManager.lockDevice(deviceID) +// LockClient is a convenience function to lock a client +func LockClient(clientID string) func() { + return globalClientMutexManager.lockClient(clientID) } func init() { - // Start cleanup routine for device mutexes - globalDeviceMutexManager.startCleanupRoutine() + // Start cleanup routine for client mutexes + globalClientMutexManager.startCleanupRoutine() } func shiftWindowRight(window []byte) { @@ -481,9 +481,9 @@ func IsAuthenticatedConnection(r *http.Request) bool { return !isLocalhost(r.RemoteAddr) } -func GetDeviceFromContext(ctx context.Context) *database.Device { - if device, ok := ctx.Value(deviceKey("device")).(*database.Device); ok { - return device +func GetClientFromContext(ctx context.Context) *database.Client { + if client, ok := ctx.Value(clientKey("client")).(*database.Client); ok { + return client } return nil } diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go index 9aeef8ca2..35f10b867 100644 --- a/pkg/api/middleware/auth_test.go +++ b/pkg/api/middleware/auth_test.go @@ -81,9 +81,9 @@ func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { t.Parallel() // Setup userDB := helpers.NewMockUserDBI() - // Mock the GetDeviceByAuthToken call to return an error for empty token + // Mock the GetClientByAuthToken call to return an error for empty token // The middleware calls this once per test case (we have 2 test cases) - userDB.On("GetDeviceByAuthToken", "").Return((*database.Device)(nil), assert.AnError).Times(2) + userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Times(2) db := &database.Database{UserDB: userDB} @@ -124,9 +124,9 @@ func TestAuthMiddleware_EncryptedRequest(t *testing.T) { t.Parallel() // Create a mock device with known shared secret testSecret := []byte("test-secret-key-32-bytes-long-ok") - testDevice := &database.Device{ - DeviceID: "test-device-id", - DeviceName: "Test Device", + testDevice := &database.Client{ + ClientID: "test-device-id", + ClientName: "Test Device", AuthTokenHash: "test-token-hash", SharedSecret: testSecret, CurrentSeq: 0, @@ -137,7 +137,7 @@ func TestAuthMiddleware_EncryptedRequest(t *testing.T) { } userDB := helpers.NewMockUserDBI() - userDB.On("GetDeviceByAuthToken", "test-auth-token").Return(testDevice, nil) + userDB.On("GetClientByAuthToken", "test-auth-token").Return(testDevice, nil) userDB.On("UpdateDeviceSequence", "test-device-id", uint64(1), mock.AnythingOfType("[]uint8"), mock.AnythingOfType("[]string")).Return(nil) @@ -179,9 +179,9 @@ func TestAuthMiddleware_EncryptedRequest(t *testing.T) { assert.InDelta(t, float64(1), jsonRPC["id"], 0.001) // JSON unmarshals numbers as float64 // Verify device is in context - device := GetDeviceFromContext(r.Context()) + device := GetClientFromContext(r.Context()) assert.NotNil(t, device) - assert.Equal(t, "test-device-id", device.DeviceID) + assert.Equal(t, "test-device-id", device.ClientID) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("authenticated success")) @@ -269,8 +269,8 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - device := &database.Device{ - DeviceID: "test-device", + device := &database.Client{ + ClientID: "test-device", CurrentSeq: tt.currentSeq, SeqWindow: tt.seqWindow, NonceCache: tt.nonceCache, @@ -345,7 +345,7 @@ func TestAuthMiddleware_InvalidRequests(t *testing.T) { t.Parallel() userDB := helpers.NewMockUserDBI() // Mock empty auth token lookup - expect it to be called once for the missing auth token test - userDB.On("GetDeviceByAuthToken", "").Return((*database.Device)(nil), assert.AnError) + userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError) db := &database.Database{UserDB: userDB} @@ -391,23 +391,23 @@ func TestAuthMiddleware_InvalidRequests(t *testing.T) { userDB.AssertExpectations(t) } -func TestGetDeviceFromContext(t *testing.T) { +func TestGetClientFromContext(t *testing.T) { t.Parallel() // Test with device in context - device := &database.Device{DeviceID: "test-device"} - ctx := context.WithValue(context.Background(), deviceKey("device"), device) + device := &database.Client{ClientID: "test-device"} + ctx := context.WithValue(context.Background(), clientKey("device"), device) - result := GetDeviceFromContext(ctx) + result := GetClientFromContext(ctx) assert.Equal(t, device, result) // Test with no device in context emptyCtx := context.Background() - result = GetDeviceFromContext(emptyCtx) + result = GetClientFromContext(emptyCtx) assert.Nil(t, result) // Test with wrong type in context - badCtx := context.WithValue(context.Background(), deviceKey("device"), "not-a-device") - result = GetDeviceFromContext(badCtx) + badCtx := context.WithValue(context.Background(), clientKey("device"), "not-a-device") + result = GetClientFromContext(badCtx) assert.Nil(t, result) } @@ -415,109 +415,109 @@ func TestGetDeviceFromContext(t *testing.T) { // prevents concurrent requests from bypassing replay protection func TestAuthMiddleware_ConcurrentRequests(t *testing.T) { t.Parallel() - + // This test verifies the mutex locking works correctly // We'll just test that concurrent access to the mutex manager is safe - + const numConcurrentRequests = 20 const deviceID = "test-device-concurrent" - + done := make(chan struct{}, numConcurrentRequests) var lockAcquired int32 - - for i := 0; i < numConcurrentRequests; i++ { + + for range numConcurrentRequests { go func() { defer func() { done <- struct{}{} }() - + // Acquire device lock - this should be thread-safe - unlockDevice := LockDevice(deviceID) - + unlockDevice := LockClient(deviceID) + // Critical section - only one goroutine should be here at a time current := atomic.AddInt32(&lockAcquired, 1) if current != 1 { t.Errorf("Race condition detected: %d goroutines in critical section", current) } - + // Simulate some work time.Sleep(1 * time.Millisecond) - + atomic.AddInt32(&lockAcquired, -1) unlockDevice() }() } // Wait for all requests to complete - for i := 0; i < numConcurrentRequests; i++ { + for range numConcurrentRequests { <-done } - + // Verify no race conditions occurred assert.Equal(t, int32(0), atomic.LoadInt32(&lockAcquired), "All locks should be released") } -// TestDeviceMutexManager_Cleanup verifies mutex cleanup works correctly -func TestDeviceMutexManager_Cleanup(t *testing.T) { +// TestClientMutexManager_Cleanup verifies mutex cleanup works correctly +func TestClientMutexManager_Cleanup(t *testing.T) { t.Parallel() - - dm := &DeviceMutexManager{} - + + dm := &ClientMutexManager{} + // Create some mutexes - mutex1 := dm.getDeviceMutex("device1") - mutex2 := dm.getDeviceMutex("device2") - mutex3 := dm.getDeviceMutex("device3") - + mutex1 := dm.getClientMutex("device1") + mutex2 := dm.getClientMutex("device2") + mutex3 := dm.getClientMutex("device3") + require.NotNil(t, mutex1) require.NotNil(t, mutex2) require.NotNil(t, mutex3) - + // Age some mutexes mutex1.lastUsed = time.Now().Add(-31 * time.Minute) // Should be cleaned mutex2.lastUsed = time.Now().Add(-20 * time.Minute) // Should remain - mutex3.lastUsed = time.Now() // Should remain - + mutex3.lastUsed = time.Now() // Should remain + // Run cleanup dm.cleanup() - + // Check that only old mutex was removed _, exists1 := dm.mutexes.Load("device1") _, exists2 := dm.mutexes.Load("device2") _, exists3 := dm.mutexes.Load("device3") - + assert.False(t, exists1, "Old mutex should be cleaned up") assert.True(t, exists2, "Recent mutex should remain") assert.True(t, exists3, "Current mutex should remain") } -// TestDeviceMutexManager_ConcurrentAccess verifies thread safety of mutex manager -func TestDeviceMutexManager_ConcurrentAccess(t *testing.T) { +// TestClientMutexManager_ConcurrentAccess verifies thread safety of mutex manager +func TestClientMutexManager_ConcurrentAccess(t *testing.T) { t.Parallel() - - dm := &DeviceMutexManager{} + + dm := &ClientMutexManager{} const numGoroutines = 50 const deviceID = "concurrent-test-device" - + done := make(chan struct{}, numGoroutines) - + // Launch multiple goroutines that get/create mutex for same device - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { go func() { defer func() { done <- struct{}{} }() - - mutex := dm.getDeviceMutex(deviceID) + + mutex := dm.getClientMutex(deviceID) require.NotNil(t, mutex) - assert.Equal(t, deviceID, mutex.deviceID) + assert.Equal(t, deviceID, mutex.clientID) }() } - + // Wait for all goroutines to complete - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { <-done } - + // Verify only one mutex was created for the device value, exists := dm.mutexes.Load(deviceID) assert.True(t, exists, "Mutex should exist") - - mutex := value.(*deviceMutex) - assert.Equal(t, deviceID, mutex.deviceID) + + mutex := value.(*clientMutex) + assert.Equal(t, deviceID, mutex.clientID) } diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go index 683ef5b16..ca4af2676 100644 --- a/pkg/api/pairing.go +++ b/pkg/api/pairing.go @@ -54,7 +54,7 @@ type PairingManager struct { } type PairingInitiateRequest struct { - DeviceName string `json:"deviceName"` + ClientName string `json:"clientName"` } type PairingInitiateResponse struct { @@ -65,11 +65,11 @@ type PairingInitiateResponse struct { type PairingCompleteRequest struct { PairingToken string `json:"pairingToken"` Verifier string `json:"verifier"` - DeviceName string `json:"deviceName"` + ClientName string `json:"clientName"` } type PairingCompleteResponse struct { - DeviceID string `json:"deviceId"` + ClientID string `json:"clientId"` AuthToken string `json:"authToken"` SharedSecret string `json:"sharedSecret"` // Base64 encoded } @@ -200,7 +200,7 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { return } - if req.PairingToken == "" || req.Verifier == "" || req.DeviceName == "" { + if req.PairingToken == "" || req.Verifier == "" || req.ClientName == "" { http.Error(w, "Missing required fields", http.StatusBadRequest) return } @@ -219,7 +219,7 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { sharedSecret := make([]byte, 32) // 256 bits for AES-256 // Construct context-specific info string for domain separation - info := []byte("zaparoo-pairing-v1|" + req.PairingToken + "|" + req.DeviceName) + info := []byte("zaparoo-pairing-v1|" + req.PairingToken + "|" + req.ClientName) kdf := hkdf.New(sha256.New, combinedSecret, session.Salt, info) if _, err := kdf.Read(sharedSecret); err != nil { @@ -231,10 +231,10 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { // Generate auth token authToken := uuid.New().String() - // Create device in database - device, err := db.UserDB.CreateDevice(req.DeviceName, authToken, sharedSecret) + // Create client in database + client, err := db.UserDB.CreateClient(req.ClientName, authToken, sharedSecret) if err != nil { - log.Error().Err(err).Msg("failed to create device") + log.Error().Err(err).Msg("failed to create client") http.Error(w, "Internal server error", http.StatusInternalServerError) return } @@ -245,7 +245,7 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { pairingManager.mu.Unlock() response := PairingCompleteResponse{ - DeviceID: device.DeviceID, + ClientID: client.ClientID, AuthToken: authToken, SharedSecret: hex.EncodeToString(sharedSecret), } @@ -258,9 +258,9 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { } log.Info(). - Str("device_id", device.DeviceID). - Str("device_name", device.DeviceName). - Msg("device paired successfully") + Str("client_id", client.ClientID). + Str("client_name", client.ClientName). + Msg("client paired successfully") } } diff --git a/pkg/api/server.go b/pkg/api/server.go index 7e9fdd4eb..8917c296a 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -469,20 +469,20 @@ func broadcastNotifications( } // Check if session is authenticated - device, authenticated := s.Get("device") + client, authenticated := s.Get("client") if !authenticated { return false // Skip unauthenticated } // Encrypt notification for this session - deviceObj, ok := device.(*database.Device) + clientObj, ok := client.(*database.Client) if !ok { - log.Error().Msg("invalid device type in session") + log.Error().Msg("invalid client type in session") return false } - encrypted, iv, err := apimiddleware.EncryptPayload(data, deviceObj.SharedSecret) + encrypted, iv, err := apimiddleware.EncryptPayload(data, clientObj.SharedSecret) if err != nil { - log.Error().Err(err).Str("device_id", deviceObj.DeviceID).Msg("failed to encrypt notification") + log.Error().Err(err).Str("client_id", clientObj.ClientID).Msg("failed to encrypt notification") return false } @@ -494,14 +494,14 @@ func broadcastNotifications( encData, err := json.Marshal(encResponse) if err != nil { - log.Error().Err(err).Str("device_id", deviceObj.DeviceID). + log.Error().Err(err).Str("client_id", clientObj.ClientID). Msg("failed to marshal encrypted notification") return false } // Send encrypted data to this session if err := s.Write(encData); err != nil { - log.Error().Err(err).Str("device_id", deviceObj.DeviceID). + log.Error().Err(err).Str("client_id", clientObj.ClientID). Msg("failed to send encrypted notification") } @@ -602,7 +602,7 @@ func handleWSMessage( // Handle authentication for remote connections if !isLocal { - device, authenticated := session.Get("device") + client, authenticated := session.Get("client") if !authenticated { // First message must be authentication err := handleWSAuthentication(session, msg, db) @@ -614,16 +614,16 @@ func handleWSMessage( } // Decrypt message for authenticated remote connection - deviceObj, ok := device.(*database.Device) + clientObj, ok := client.(*database.Client) if !ok { - log.Error().Msg("invalid device type in session") + log.Error().Msg("invalid client type in session") err := sendWSError(session, uuid.Nil, JSONRPCErrorInternalError) if err != nil { log.Error().Err(err).Msg("failed to send WebSocket error") } return } - decryptedMsg, err := handleWSDecryption(session, msg, deviceObj, db) + decryptedMsg, err := handleWSDecryption(session, msg, clientObj, db) if err != nil { log.Error().Err(err).Msg("WebSocket decryption failed") err := sendWSError(session, uuid.Nil, JSONRPCErrorInvalidRequest) @@ -653,13 +653,13 @@ func handleWSMessage( } else { // Encrypt response for remote authenticated connections if !isLocal { - if device, authenticated := session.Get("device"); authenticated { - deviceObj, ok := device.(*database.Device) + if client, authenticated := session.Get("client"); authenticated { + clientObj, ok := client.(*database.Client) if !ok { - log.Error().Msg("invalid device type in session") + log.Error().Msg("invalid client type in session") return } - err := sendWSResponseEncrypted(session, id, resp, deviceObj) + err := sendWSResponseEncrypted(session, id, resp, clientObj) if err != nil { log.Error().Err(err).Msg("error sending encrypted response") } @@ -690,19 +690,19 @@ func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Da return errors.New("missing auth token") } - // Validate auth token and get device - device, err := db.UserDB.GetDeviceByAuthToken(authMsg.AuthToken) + // Validate auth token and get client + client, err := db.UserDB.GetClientByAuthToken(authMsg.AuthToken) if err != nil { return fmt.Errorf("invalid auth token: %w", err) } - // Store device in session - session.Set("device", device) + // Store client in session + session.Set("client", client) // Send authentication success response authResponse := map[string]any{ "authenticated": true, - "device_id": device.DeviceID, + "client_id": client.ClientID, } responseData, _ := json.Marshal(authResponse) @@ -711,11 +711,11 @@ func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Da return fmt.Errorf("failed to send auth response: %w", err) } - log.Debug().Str("device_id", device.DeviceID).Msg("WebSocket authenticated") + log.Debug().Str("client_id", client.ClientID).Msg("WebSocket authenticated") return nil } -func handleWSDecryption(_ *melody.Session, msg []byte, device *database.Device, db *database.Database) ([]byte, error) { +func handleWSDecryption(_ *melody.Session, msg []byte, client *database.Client, db *database.Database) ([]byte, error) { // Parse encrypted message var encMsg apimiddleware.EncryptedRequest if err := json.Unmarshal(msg, &encMsg); err != nil { @@ -723,7 +723,7 @@ func handleWSDecryption(_ *melody.Session, msg []byte, device *database.Device, } // Decrypt payload (we can reuse the middleware function by creating a temporary import) - decryptedPayload, err := apimiddleware.DecryptPayload(encMsg.Encrypted, encMsg.IV, device.SharedSecret) + decryptedPayload, err := apimiddleware.DecryptPayload(encMsg.Encrypted, encMsg.IV, client.SharedSecret) if err != nil { return nil, fmt.Errorf("decryption failed: %w", err) } @@ -734,29 +734,29 @@ func handleWSDecryption(_ *melody.Session, msg []byte, device *database.Device, return nil, fmt.Errorf("invalid decrypted payload: %w", unmarshalErr) } - // CRITICAL SECTION: Acquire device lock to prevent race conditions + // Acquire client lock to prevent race conditions // between validation and database update - unlockDevice := apimiddleware.LockDevice(device.DeviceID) - defer unlockDevice() + unlockClient := apimiddleware.LockClient(client.ClientID) + defer unlockClient() - // Re-fetch device state under lock to get latest sequence/nonce state - freshDevice, err := db.UserDB.GetDeviceByID(device.DeviceID) + // Re-fetch client state under lock to get latest sequence/nonce state + freshClient, err := db.UserDB.GetClientByID(client.ClientID) if err != nil { - return nil, fmt.Errorf("failed to re-fetch device under lock: %w", err) + return nil, fmt.Errorf("failed to re-fetch client under lock: %w", err) } - - // Validate sequence and nonce with fresh device state - if !apimiddleware.ValidateSequenceAndNonce(freshDevice, payload.Seq, payload.Nonce) { + + // Validate sequence and nonce with fresh client state + if !apimiddleware.ValidateSequenceAndNonce(freshClient, payload.Seq, payload.Nonce) { return nil, errors.New("invalid sequence or replay detected") } - // Update device state under lock + // Update client state under lock userDB, ok := db.UserDB.(*userdb.UserDB) if !ok { return nil, errors.New("failed to cast UserDB to concrete type") } - if updateErr := apimiddleware.UpdateDeviceState(userDB, freshDevice, payload.Seq, payload.Nonce); updateErr != nil { - return nil, fmt.Errorf("failed to update device state: %w", updateErr) + if updateErr := apimiddleware.UpdateClientState(userDB, freshClient, payload.Seq, payload.Nonce); updateErr != nil { + return nil, fmt.Errorf("failed to update client state: %w", updateErr) } // Return JSON-RPC payload without sequence/nonce @@ -776,7 +776,7 @@ func handleWSDecryption(_ *melody.Session, msg []byte, device *database.Device, return result, nil } -func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, device *database.Device) error { +func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, client *database.Client) error { // Create response object resp := models.ResponseObject{ JSONRPC: "2.0", @@ -791,7 +791,7 @@ func sendWSResponseEncrypted(session *melody.Session, id uuid.UUID, result any, } // Encrypt the response - encrypted, iv, err := apimiddleware.EncryptPayload(data, device.SharedSecret) + encrypted, iv, err := apimiddleware.EncryptPayload(data, client.SharedSecret) if err != nil { return fmt.Errorf("failed to encrypt response: %w", err) } diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 60bc4f0a7..0ee72bb80 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -49,8 +49,8 @@ type Flags struct { Config *bool ShowLoader *string ShowPairingCode *bool - ListDevices *bool - RevokeDevice *string + ListClients *bool + RevokeClient *string ShowPicker *string Reload *bool } @@ -101,17 +101,17 @@ func SetupFlags() *Flags { ShowPairingCode: flag.Bool( "show-pairing-code", false, - "display QR code for device pairing", + "display QR code for client pairing", ), - ListDevices: flag.Bool( - "list-devices", + ListClients: flag.Bool( + "list-clients", false, - "list all paired devices", + "list all paired clients", ), - RevokeDevice: flag.String( - "revoke-device", + RevokeClient: flag.String( + "revoke-client", "", - "revoke access for device by ID", + "revoke access for client by ID", ), } } @@ -162,11 +162,11 @@ func (f *Flags) Post(cfg *config.Instance, pl platforms.Platform) { case *f.ShowPairingCode: handleShowPairingCode(cfg, pl) os.Exit(0) - case *f.ListDevices: - handleListDevices(cfg, pl) + case *f.ListClients: + handleListClients(cfg, pl) os.Exit(0) - case isFlagPassed("revoke-device"): - handleRevokeDevice(cfg, pl, *f.RevokeDevice) + case isFlagPassed("revoke-client"): + handleRevokeClient(cfg, pl, *f.RevokeClient) os.Exit(0) case isFlagPassed("write"): if *f.Write == "" { diff --git a/pkg/cli/devices.go b/pkg/cli/clients.go similarity index 83% rename from pkg/cli/devices.go rename to pkg/cli/clients.go index ea725435c..09251e97b 100644 --- a/pkg/cli/devices.go +++ b/pkg/cli/clients.go @@ -120,7 +120,7 @@ func handleShowPairingCode(cfg *config.Instance, pl platforms.Platform) { if _, err := fmt.Printf("Expires in: %d seconds\n", pairingResp.ExpiresIn); err != nil { log.Error().Err(err).Msg("failed to print expiration time") } - if _, err := fmt.Print("\nWaiting for device to pair... (Ctrl+C to cancel)\n"); err != nil { + if _, err := fmt.Print("\nWaiting for client to pair... (Ctrl+C to cancel)\n"); err != nil { log.Error().Err(err).Msg("failed to print waiting message") } @@ -128,8 +128,8 @@ func handleShowPairingCode(cfg *config.Instance, pl platforms.Platform) { select {} } -func handleListDevices(_ *config.Instance, pl platforms.Platform) { - // Open user database to list devices +func handleListClients(_ *config.Instance, pl platforms.Platform) { + // Open user database to list clients userDB, err := userdb.OpenUserDB(context.Background(), pl) if err != nil { log.Error().Err(err).Msg("failed to open user database") @@ -140,54 +140,54 @@ func handleListDevices(_ *config.Instance, pl platforms.Platform) { } defer func() { _ = userDB.Close() }() - devices, err := userDB.GetAllDevices() + clients, err := userDB.GetAllClients() if err != nil { - log.Error().Err(err).Msg("failed to get devices") - if _, writeErr := fmt.Fprintf(os.Stderr, "Error getting devices: %v\n", err); writeErr != nil { + log.Error().Err(err).Msg("failed to get clients") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error getting clients: %v\n", err); writeErr != nil { log.Error().Err(writeErr).Msg("failed to write error message") } return } - if len(devices) == 0 { - if _, err := fmt.Println("No paired devices found."); err != nil { + if len(clients) == 0 { + if _, err := fmt.Println("No paired clients found."); err != nil { log.Error().Err(err).Msg("failed to print message") } return } - if _, err := fmt.Print("Paired devices:\n\n"); err != nil { + if _, err := fmt.Print("Paired clients:\n\n"); err != nil { log.Error().Err(err).Msg("failed to print header") } - if _, err := fmt.Printf("%-36s %-20s %-10s %s\n", "Device ID", "Name", "Sequence", "Last Seen"); err != nil { + if _, err := fmt.Printf("%-36s %-20s %-10s %s\n", "Client ID", "Name", "Sequence", "Last Seen"); err != nil { log.Error().Err(err).Msg("failed to print column headers") } if _, err := fmt.Printf("%s\n", strings.Repeat("-", 80)); err != nil { log.Error().Err(err).Msg("failed to print separator") } - for i := range devices { - device := &devices[i] + for i := range clients { + client := &clients[i] if _, err := fmt.Printf("%-36s %-20s %-10d %s\n", - device.DeviceID, - device.DeviceName, - device.CurrentSeq, - device.LastSeen.Format("2006-01-02 15:04:05"), + client.ClientID, + client.ClientName, + client.CurrentSeq, + client.LastSeen.Format("2006-01-02 15:04:05"), ); err != nil { - log.Error().Err(err).Msg("failed to print device info") + log.Error().Err(err).Msg("failed to print client info") } } } -func handleRevokeDevice(_ *config.Instance, pl platforms.Platform, deviceID string) { - if deviceID == "" { - if _, err := fmt.Fprint(os.Stderr, "Error: device ID is required\n"); err != nil { +func handleRevokeClient(_ *config.Instance, pl platforms.Platform, clientID string) { + if clientID == "" { + if _, err := fmt.Fprint(os.Stderr, "Error: client ID is required\n"); err != nil { log.Error().Err(err).Msg("failed to write error message") } os.Exit(1) } - // Open user database to revoke device + // Open user database to revoke client userDB, err := userdb.OpenUserDB(context.Background(), pl) if err != nil { log.Error().Err(err).Msg("failed to open user database") @@ -198,16 +198,16 @@ func handleRevokeDevice(_ *config.Instance, pl platforms.Platform, deviceID stri } defer func() { _ = userDB.Close() }() - err = userDB.DeleteDevice(deviceID) + err = userDB.DeleteClient(clientID) if err != nil { - log.Error().Err(err).Msg("failed to delete device") - if _, writeErr := fmt.Fprintf(os.Stderr, "Error deleting device: %v\n", err); writeErr != nil { + log.Error().Err(err).Msg("failed to delete client") + if _, writeErr := fmt.Fprintf(os.Stderr, "Error deleting client: %v\n", err); writeErr != nil { log.Error().Err(writeErr).Msg("failed to write error message") } return } - if _, err := fmt.Printf("Device %s has been revoked successfully.\n", deviceID); err != nil { + if _, err := fmt.Printf("Client %s has been revoked successfully.\n", clientID); err != nil { log.Error().Err(err).Msg("failed to print success message") } } diff --git a/pkg/database/database.go b/pkg/database/database.go index e42c1a463..73f9a4a83 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -69,11 +69,11 @@ type System struct { DBID int64 } -type Device struct { +type Client struct { CreatedAt time.Time `json:"createdAt"` LastSeen time.Time `json:"lastSeen"` - DeviceID string `json:"deviceId"` - DeviceName string `json:"deviceName"` + ClientID string `json:"clientId"` + ClientName string `json:"clientName"` AuthTokenHash string `json:"-"` SharedSecret []byte `json:"-"` SeqWindow []byte `json:"-"` @@ -166,12 +166,12 @@ type UserDBI interface { GetZapLinkHost(host string) (bool, bool, error) UpdateZapLinkCache(url string, zapscript string) error GetZapLinkCache(url string) (string, error) - CreateDevice(deviceName, authToken string, sharedSecret []byte) (*Device, error) - GetDeviceByAuthToken(authToken string) (*Device, error) - GetDeviceByID(deviceID string) (*Device, error) - UpdateDeviceSequence(deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string) error - GetAllDevices() ([]Device, error) - DeleteDevice(deviceID string) error + CreateClient(clientName, authToken string, sharedSecret []byte) (*Client, error) + GetClientByAuthToken(authToken string) (*Client, error) + GetClientByID(clientID string) (*Client, error) + UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error + GetAllClients() ([]Client, error) + DeleteClient(clientID string) error } type MediaDBI interface { diff --git a/pkg/database/userdb/devices.go b/pkg/database/userdb/clients.go similarity index 55% rename from pkg/database/userdb/devices.go rename to pkg/database/userdb/clients.go index 187e17768..362a6fd7c 100644 --- a/pkg/database/userdb/devices.go +++ b/pkg/database/userdb/clients.go @@ -32,18 +32,18 @@ import ( "github.com/google/uuid" ) -func (db *UserDB) CreateDevice(deviceName, authToken string, sharedSecret []byte) (*database.Device, error) { +func (db *UserDB) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) { if db.sql == nil { return nil, ErrNullSQL } - deviceID := uuid.New().String() + clientID := uuid.New().String() authTokenHash := hashAuthToken(authToken) now := time.Now().Unix() - device := &database.Device{ - DeviceID: deviceID, - DeviceName: deviceName, + client := &database.Client{ + ClientID: clientID, + ClientName: clientName, AuthTokenHash: authTokenHash, SharedSecret: sharedSecret, CurrentSeq: 0, @@ -53,36 +53,36 @@ func (db *UserDB) CreateDevice(deviceName, authToken string, sharedSecret []byte LastSeen: time.Unix(now, 0), } - nonceCacheJSON, err := json.Marshal(device.NonceCache) + nonceCacheJSON, err := json.Marshal(client.NonceCache) if err != nil { return nil, fmt.Errorf("failed to marshal nonce cache: %w", err) } query := ` - INSERT INTO devices (device_id, device_name, auth_token_hash, shared_secret, + INSERT INTO clients (client_id, client_name, auth_token_hash, shared_secret, current_seq, seq_window, nonce_cache, created_at, last_seen) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err = db.sql.ExecContext(context.Background(), query, - device.DeviceID, - device.DeviceName, - device.AuthTokenHash, - device.SharedSecret, - device.CurrentSeq, - device.SeqWindow, + client.ClientID, + client.ClientName, + client.AuthTokenHash, + client.SharedSecret, + client.CurrentSeq, + client.SeqWindow, string(nonceCacheJSON), now, now, ) if err != nil { - return nil, fmt.Errorf("failed to create device: %w", err) + return nil, fmt.Errorf("failed to create client: %w", err) } - return device, nil + return client, nil } -func (db *UserDB) GetDeviceByAuthToken(authToken string) (*database.Device, error) { +func (db *UserDB) GetClientByAuthToken(authToken string) (*database.Client, error) { if db.sql == nil { return nil, ErrNullSQL } @@ -90,85 +90,85 @@ func (db *UserDB) GetDeviceByAuthToken(authToken string) (*database.Device, erro authTokenHash := hashAuthToken(authToken) query := ` - SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, seq_window, nonce_cache, created_at, last_seen - FROM devices + FROM clients WHERE auth_token_hash = ? ` - var device database.Device + var client database.Client var nonceCacheJSON string var createdAt, lastSeen int64 err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan( - &device.DeviceID, - &device.DeviceName, - &device.AuthTokenHash, - &device.SharedSecret, - &device.CurrentSeq, - &device.SeqWindow, + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, &nonceCacheJSON, &createdAt, &lastSeen, ) if err != nil { - return nil, fmt.Errorf("device not found: %w", err) + return nil, fmt.Errorf("client not found: %w", err) } - device.CreatedAt = time.Unix(createdAt, 0) - device.LastSeen = time.Unix(lastSeen, 0) + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) - err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) if err != nil { - device.NonceCache = make([]string, 0) // Fallback to empty cache + client.NonceCache = make([]string, 0) // Fallback to empty cache } - return &device, nil + return &client, nil } -func (db *UserDB) GetDeviceByID(deviceID string) (*database.Device, error) { +func (db *UserDB) GetClientByID(clientID string) (*database.Client, error) { if db.sql == nil { return nil, ErrNullSQL } query := ` - SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, seq_window, nonce_cache, created_at, last_seen - FROM devices - WHERE device_id = ? + FROM clients + WHERE client_id = ? ` - var device database.Device + var client database.Client var nonceCacheJSON string var createdAt, lastSeen int64 - err := db.sql.QueryRowContext(context.Background(), query, deviceID).Scan( - &device.DeviceID, - &device.DeviceName, - &device.AuthTokenHash, - &device.SharedSecret, - &device.CurrentSeq, - &device.SeqWindow, + err := db.sql.QueryRowContext(context.Background(), query, clientID).Scan( + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, &nonceCacheJSON, &createdAt, &lastSeen, ) if err != nil { - return nil, fmt.Errorf("device not found: %w", err) + return nil, fmt.Errorf("client not found: %w", err) } - device.CreatedAt = time.Unix(createdAt, 0) - device.LastSeen = time.Unix(lastSeen, 0) + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) - err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) if err != nil { - device.NonceCache = make([]string, 0) // Fallback to empty cache + client.NonceCache = make([]string, 0) // Fallback to empty cache } - return &device, nil + return &client, nil } -func (db *UserDB) UpdateDeviceSequence(deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string) error { +func (db *UserDB) UpdateClientSequence(clientID string, newSeq uint64, seqWindow []byte, nonceCache []string) error { if db.sql == nil { return ErrNullSQL } @@ -179,87 +179,87 @@ func (db *UserDB) UpdateDeviceSequence(deviceID string, newSeq uint64, seqWindow } query := ` - UPDATE devices + UPDATE clients SET current_seq = ?, seq_window = ?, nonce_cache = ?, last_seen = ? - WHERE device_id = ? + WHERE client_id = ? ` _, err = db.sql.ExecContext( - context.Background(), query, newSeq, seqWindow, string(nonceCacheJSON), time.Now().Unix(), deviceID, + context.Background(), query, newSeq, seqWindow, string(nonceCacheJSON), time.Now().Unix(), clientID, ) if err != nil { - return fmt.Errorf("failed to update device sequence: %w", err) + return fmt.Errorf("failed to update client sequence: %w", err) } return nil } -func (db *UserDB) GetAllDevices() ([]database.Device, error) { +func (db *UserDB) GetAllClients() ([]database.Client, error) { if db.sql == nil { return nil, ErrNullSQL } query := ` - SELECT device_id, device_name, auth_token_hash, shared_secret, current_seq, + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, seq_window, nonce_cache, created_at, last_seen - FROM devices + FROM clients ORDER BY last_seen DESC ` rows, err := db.sql.QueryContext(context.Background(), query) if err != nil { - return nil, fmt.Errorf("failed to query devices: %w", err) + return nil, fmt.Errorf("failed to query clients: %w", err) } defer func() { _ = rows.Close() }() - devices := make([]database.Device, 0) + clients := make([]database.Client, 0) for rows.Next() { - var device database.Device + var client database.Client var nonceCacheJSON string var createdAt, lastSeen int64 scanErr := rows.Scan( - &device.DeviceID, - &device.DeviceName, - &device.AuthTokenHash, - &device.SharedSecret, - &device.CurrentSeq, - &device.SeqWindow, + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, &nonceCacheJSON, &createdAt, &lastSeen, ) if scanErr != nil { - return nil, fmt.Errorf("failed to scan device row: %w", scanErr) + return nil, fmt.Errorf("failed to scan client row: %w", scanErr) } - device.CreatedAt = time.Unix(createdAt, 0) - device.LastSeen = time.Unix(lastSeen, 0) + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) - err = json.Unmarshal([]byte(nonceCacheJSON), &device.NonceCache) + err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) if err != nil { - device.NonceCache = make([]string, 0) // Fallback to empty cache + client.NonceCache = make([]string, 0) // Fallback to empty cache } - devices = append(devices, device) + clients = append(clients, client) } if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error reading device rows: %w", err) + return nil, fmt.Errorf("error reading client rows: %w", err) } - return devices, nil + return clients, nil } -func (db *UserDB) DeleteDevice(deviceID string) error { +func (db *UserDB) DeleteClient(clientID string) error { if db.sql == nil { return ErrNullSQL } - query := `DELETE FROM devices WHERE device_id = ?` - result, err := db.sql.ExecContext(context.Background(), query, deviceID) + query := `DELETE FROM clients WHERE client_id = ?` + result, err := db.sql.ExecContext(context.Background(), query, clientID) if err != nil { - return fmt.Errorf("failed to delete device: %w", err) + return fmt.Errorf("failed to delete client: %w", err) } rowsAffected, err := result.RowsAffected() @@ -268,7 +268,7 @@ func (db *UserDB) DeleteDevice(deviceID string) error { } if rowsAffected == 0 { - return errors.New("device not found") + return errors.New("client not found") } return nil diff --git a/pkg/database/userdb/migrations/20250912231204_devices_auth.sql b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql similarity index 64% rename from pkg/database/userdb/migrations/20250912231204_devices_auth.sql rename to pkg/database/userdb/migrations/20250912231204_clients_auth.sql index 3bddc26ad..b3380c6d7 100644 --- a/pkg/database/userdb/migrations/20250912231204_devices_auth.sql +++ b/pkg/database/userdb/migrations/20250912231204_clients_auth.sql @@ -1,8 +1,8 @@ -- +goose Up -- +goose StatementBegin -CREATE TABLE devices ( - device_id TEXT PRIMARY KEY, - device_name TEXT NOT NULL, +CREATE TABLE clients ( + client_id TEXT PRIMARY KEY, + client_name TEXT NOT NULL, auth_token_hash TEXT NOT NULL UNIQUE, shared_secret BLOB NOT NULL, current_seq INTEGER DEFAULT 0, @@ -12,11 +12,11 @@ CREATE TABLE devices ( last_seen INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) ); -CREATE INDEX idx_devices_auth_token ON devices(auth_token_hash); -CREATE INDEX idx_devices_last_seen ON devices(last_seen); +CREATE INDEX idx_clients_auth_token ON clients(auth_token_hash); +CREATE INDEX idx_clients_last_seen ON clients(last_seen); -- +goose StatementEnd -- +goose Down -- +goose StatementBegin -DROP TABLE devices; +DROP TABLE clients; -- +goose StatementEnd \ No newline at end of file diff --git a/pkg/testing/helpers/db_mocks.go b/pkg/testing/helpers/db_mocks.go index 607e07bdd..2607e8d64 100644 --- a/pkg/testing/helpers/db_mocks.go +++ b/pkg/testing/helpers/db_mocks.go @@ -244,77 +244,77 @@ func (m *MockUserDBI) GetZapLinkCache(url string) (string, error) { return args.String(0), args.Error(1) } -// Device authentication methods -func (m *MockUserDBI) CreateDevice(deviceName, authToken string, sharedSecret []byte) (*database.Device, error) { - args := m.Called(deviceName, authToken, sharedSecret) - if device, ok := args.Get(0).(*database.Device); ok { +// Client authentication methods +func (m *MockUserDBI) CreateClient(clientName, authToken string, sharedSecret []byte) (*database.Client, error) { + args := m.Called(clientName, authToken, sharedSecret) + if client, ok := args.Get(0).(*database.Client); ok { if err := args.Error(1); err != nil { - return device, fmt.Errorf("mock UserDBI create device failed: %w", err) + return client, fmt.Errorf("mock UserDBI create client failed: %w", err) } - return device, nil + return client, nil } if err := args.Error(1); err != nil { - return nil, fmt.Errorf("mock UserDBI create device failed: %w", err) + return nil, fmt.Errorf("mock UserDBI create client failed: %w", err) } return nil, ErrMockNotConfigured } -func (m *MockUserDBI) GetDeviceByAuthToken(authToken string) (*database.Device, error) { +func (m *MockUserDBI) GetClientByAuthToken(authToken string) (*database.Client, error) { args := m.Called(authToken) - if device, ok := args.Get(0).(*database.Device); ok { + if client, ok := args.Get(0).(*database.Client); ok { if err := args.Error(1); err != nil { - return device, fmt.Errorf("mock UserDBI get device by auth token failed: %w", err) + return client, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err) } - return device, nil + return client, nil } if err := args.Error(1); err != nil { - return nil, fmt.Errorf("mock UserDBI get device by auth token failed: %w", err) + return nil, fmt.Errorf("mock UserDBI get client by auth token failed: %w", err) } return nil, ErrMockNotConfigured } -func (m *MockUserDBI) GetDeviceByID(deviceID string) (*database.Device, error) { - args := m.Called(deviceID) - if device, ok := args.Get(0).(*database.Device); ok { +func (m *MockUserDBI) GetClientByID(clientID string) (*database.Client, error) { + args := m.Called(clientID) + if client, ok := args.Get(0).(*database.Client); ok { if err := args.Error(1); err != nil { - return device, fmt.Errorf("mock UserDBI get device by ID failed: %w", err) + return client, fmt.Errorf("mock UserDBI get client by ID failed: %w", err) } - return device, nil + return client, nil } if err := args.Error(1); err != nil { - return nil, fmt.Errorf("mock UserDBI get device by ID failed: %w", err) + return nil, fmt.Errorf("mock UserDBI get client by ID failed: %w", err) } return nil, ErrMockNotConfigured } -func (m *MockUserDBI) UpdateDeviceSequence( - deviceID string, newSeq uint64, seqWindow []byte, nonceCache []string, +func (m *MockUserDBI) UpdateClientSequence( + clientID string, newSeq uint64, seqWindow []byte, nonceCache []string, ) error { - args := m.Called(deviceID, newSeq, seqWindow, nonceCache) + args := m.Called(clientID, newSeq, seqWindow, nonceCache) if err := args.Error(0); err != nil { - return fmt.Errorf("mock UserDBI update device sequence failed: %w", err) + return fmt.Errorf("mock UserDBI update client sequence failed: %w", err) } return nil } -func (m *MockUserDBI) GetAllDevices() ([]database.Device, error) { +func (m *MockUserDBI) GetAllClients() ([]database.Client, error) { args := m.Called() - if devices, ok := args.Get(0).([]database.Device); ok { + if clients, ok := args.Get(0).([]database.Client); ok { if err := args.Error(1); err != nil { - return devices, fmt.Errorf("mock UserDBI get all devices failed: %w", err) + return clients, fmt.Errorf("mock UserDBI get all clients failed: %w", err) } - return devices, nil + return clients, nil } if err := args.Error(1); err != nil { - return nil, fmt.Errorf("mock UserDBI get all devices failed: %w", err) + return nil, fmt.Errorf("mock UserDBI get all clients failed: %w", err) } return nil, nil } -func (m *MockUserDBI) DeleteDevice(deviceID string) error { - args := m.Called(deviceID) +func (m *MockUserDBI) DeleteClient(clientID string) error { + args := m.Called(clientID) if err := args.Error(0); err != nil { - return fmt.Errorf("mock UserDBI delete device failed: %w", err) + return fmt.Errorf("mock UserDBI delete client failed: %w", err) } return nil } From bc81c3a34b926e26964debf3652e1fe63c926194 Mon Sep 17 00:00:00 2001 From: Callan Barrett Date: Sat, 13 Sep 2025 10:24:13 +0800 Subject: [PATCH 3/5] improvements --- docs/api/authentication.md | 618 ++++++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 6 + pkg/api/middleware/auth.go | 111 ++---- pkg/api/middleware/auth_test.go | 149 +++++++- pkg/api/pairing.go | 111 +++++- pkg/api/server.go | 164 ++++++++- pkg/config/config.go | 6 + pkg/database/userdb/clients.go | 56 ++- 9 files changed, 1085 insertions(+), 141 deletions(-) create mode 100644 docs/api/authentication.md diff --git a/docs/api/authentication.md b/docs/api/authentication.md new file mode 100644 index 000000000..35b580536 --- /dev/null +++ b/docs/api/authentication.md @@ -0,0 +1,618 @@ +# Zaparoo Core Secure API Authentication + +This document provides a complete guide for developers implementing clients that connect to Zaparoo Core using the secure authentication layer. + +## Table of Contents + +- [Overview](#overview) +- [Security Model](#security-model) +- [Client Pairing](#client-pairing) +- [HTTP API Usage](#http-api-usage) +- [WebSocket API Usage](#websocket-api-usage) +- [Example Implementations](#example-implementations) +- [Error Handling](#error-handling) +- [Security Best Practices](#security-best-practices) + +## Overview + +Zaparoo Core implements a secure authentication system for remote client connections while maintaining backward compatibility for localhost connections. The security model uses: + +- **AES-256-GCM** encryption for all remote API communications +- **Argon2id + HKDF** key derivation for client pairing +- **Sequence numbers + nonces** for replay attack protection +- **Per-client state management** for concurrent access safety + +### Connection Types + +| Connection Type | Authentication Required | Encryption Required | +| -------------------------- | ----------------------- | ------------------- | +| Localhost (127.0.0.1, ::1) | ❌ No | ❌ No | +| Remote (all other IPs) | ✅ Yes | ✅ Yes | + +## Security Model + +### Encryption Format + +All remote API requests must use the following encrypted format: + +```json +{ + "encrypted": "base64-encoded-ciphertext", + "iv": "base64-encoded-initialization-vector", + "authToken": "client-auth-token" +} +``` + +### Decrypted Payload Format + +The decrypted payload contains the standard JSON-RPC request plus security fields: + +```json +{ + "jsonrpc": "2.0", + "method": "system.version", + "id": 1, + "params": {}, + "seq": 123, + "nonce": "unique-request-nonce" +} +``` + +## Client Pairing + +Before making authenticated requests, clients must complete a pairing process to obtain shared encryption keys. + +### Step 1: Initiate Pairing + +**Request:** + +```http +POST /api/pair/initiate +Content-Type: application/json + +{ + "clientName": "MyApp_v1.0" +} +``` + +**Response:** + +```json +{ + "pairingToken": "550e8400-e29b-41d4-a716-446655440000", + "expiresIn": 300 +} +``` + +### Step 2: Complete Pairing + +The client must obtain a verification code through an out-of-band method (QR code, manual entry, etc.) and complete pairing. + +**Request:** + +```http +POST /api/pair/complete +Content-Type: application/json + +{ + "pairingToken": "550e8400-e29b-41d4-a716-446655440000", + "verifier": "user-provided-verification-code", + "clientName": "MyApp_v1.0" +} +``` + +**Response:** + +```json +{ + "clientId": "client-uuid-here", + "authToken": "auth-token-uuid-here", + "sharedSecret": "64-character-hex-encoded-key" +} +``` + +### Client Name Requirements + +- Length: 1-100 characters +- Characters: letters, numbers, underscore, dash only (`^[a-zA-Z0-9_-]+$`) +- Must be unique per client instance + +## HTTP API Usage + +### Making Authenticated Requests + +1. **Prepare the JSON-RPC payload** with sequence number and nonce +2. **Encrypt the payload** using AES-256-GCM +3. **Send the encrypted request** to the API endpoint + +### Example: Get System Version + +```javascript +// 1. Prepare JSON-RPC payload +const payload = { + jsonrpc: "2.0", + method: "system.version", + id: 1, + seq: ++sequenceNumber, + nonce: generateNonce(), +}; + +// 2. Encrypt payload +const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret); + +// 3. Send request +const response = await fetch("http://zaparoo-host:7497/api/v0.1", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: authToken, + }), +}); + +// 4. Decrypt response +const encryptedResponse = await response.json(); +const decryptedResponse = await decryptPayload( + encryptedResponse.encrypted, + encryptedResponse.iv, + sharedSecret +); +``` + +## WebSocket API Usage + +WebSocket connections require a two-step authentication process: + +### Step 1: Connect and Authenticate + +```javascript +const ws = new WebSocket("ws://zaparoo-host:7497/api/v0.1"); + +ws.onopen = () => { + // Send authentication message + ws.send( + JSON.stringify({ + authToken: authToken, + }) + ); +}; + +ws.onmessage = (event) => { + const message = JSON.parse(event.data); + + if (message.authenticated) { + console.log("WebSocket authenticated successfully"); + // Now ready to send encrypted requests + } else { + // Handle encrypted response + handleEncryptedMessage(message); + } +}; +``` + +### Step 2: Send Encrypted Requests + +```javascript +async function sendRequest(method, params = {}) { + const payload = { + jsonrpc: "2.0", + method: method, + id: generateId(), + params: params, + seq: ++sequenceNumber, + nonce: generateNonce(), + }; + + const encrypted = await encryptPayload(JSON.stringify(payload), sharedSecret); + + ws.send( + JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: authToken, + }) + ); +} + +// Example usage +await sendRequest("system.version"); +await sendRequest("search.games", { query: "mario" }); +``` + +### Handling Encrypted Responses + +```javascript +async function handleEncryptedMessage(encryptedMessage) { + if (encryptedMessage.encrypted && encryptedMessage.iv) { + const decrypted = await decryptPayload( + encryptedMessage.encrypted, + encryptedMessage.iv, + sharedSecret + ); + + const response = JSON.parse(decrypted); + console.log("Received response:", response); + } +} +``` + +## Example Implementations + +### JavaScript/Node.js Client + +```javascript +import crypto from "crypto"; + +class ZaparooSecureClient { + constructor(host, port = 7497) { + this.baseUrl = `http://${host}:${port}`; + this.wsUrl = `ws://${host}:${port}`; + this.sequenceNumber = 0; + this.authToken = null; + this.sharedSecret = null; + } + + // Pairing process + async pair(clientName, verifier) { + // Step 1: Initiate pairing + const initResponse = await fetch(`${this.baseUrl}/api/pair/initiate`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ clientName }), + }); + + const { pairingToken } = await initResponse.json(); + + // Step 2: Complete pairing (verifier obtained out-of-band) + const completeResponse = await fetch(`${this.baseUrl}/api/pair/complete`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + pairingToken, + verifier, + clientName, + }), + }); + + const result = await completeResponse.json(); + this.authToken = result.authToken; + this.sharedSecret = Buffer.from(result.sharedSecret, "hex"); + + return result; + } + + // Encryption utilities + async encryptPayload(data) { + const iv = crypto.randomBytes(12); + const cipher = crypto.createCipherGCM("aes-256-gcm", this.sharedSecret); + cipher.setAAD(Buffer.alloc(0)); + + const encrypted = Buffer.concat([ + cipher.update(data, "utf8"), + cipher.final(), + ]); + + const authTag = cipher.getAuthTag(); + const ciphertext = Buffer.concat([encrypted, authTag]); + + return { + ciphertext: ciphertext.toString("base64"), + iv: iv.toString("base64"), + }; + } + + async decryptPayload(encryptedB64, ivB64) { + const encrypted = Buffer.from(encryptedB64, "base64"); + const iv = Buffer.from(ivB64, "base64"); + + const authTag = encrypted.slice(-16); + const ciphertext = encrypted.slice(0, -16); + + const decipher = crypto.createDecipherGCM("aes-256-gcm", this.sharedSecret); + decipher.setAuthTag(authTag); + decipher.setAAD(Buffer.alloc(0)); + + const decrypted = Buffer.concat([ + decipher.update(ciphertext), + decipher.final(), + ]); + + return decrypted.toString("utf8"); + } + + generateNonce() { + return crypto.randomBytes(16).toString("hex"); + } + + // HTTP API request + async request(method, params = {}) { + const payload = { + jsonrpc: "2.0", + method, + id: Date.now(), + params, + seq: ++this.sequenceNumber, + nonce: this.generateNonce(), + }; + + const encrypted = await this.encryptPayload(JSON.stringify(payload)); + + const response = await fetch(`${this.baseUrl}/api/v0.1`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + encrypted: encrypted.ciphertext, + iv: encrypted.iv, + authToken: this.authToken, + }), + }); + + const encryptedResponse = await response.json(); + const decryptedData = await this.decryptPayload( + encryptedResponse.encrypted, + encryptedResponse.iv + ); + + return JSON.parse(decryptedData); + } +} + +// Usage example +const client = new ZaparooSecureClient("192.168.1.100"); + +// Pair with server (verifier obtained from QR code/user input) +await client.pair("MyApp_v1.0", "verification-code-from-qr"); + +// Make API requests +const version = await client.request("system.version"); +const games = await client.request("search.games", { query: "zelda" }); +``` + +### Python Client + +```python +import asyncio +import json +import secrets +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +import aiohttp +import websockets + +class ZaparooSecureClient: + def __init__(self, host, port=7497): + self.base_url = f"http://{host}:{port}" + self.ws_url = f"ws://{host}:{port}" + self.sequence_number = 0 + self.auth_token = None + self.shared_secret = None + + async def pair(self, client_name, verifier): + async with aiohttp.ClientSession() as session: + # Initiate pairing + async with session.post( + f"{self.base_url}/api/pair/initiate", + json={"clientName": client_name} + ) as response: + result = await response.json() + pairing_token = result["pairingToken"] + + # Complete pairing + async with session.post( + f"{self.base_url}/api/pair/complete", + json={ + "pairingToken": pairing_token, + "verifier": verifier, + "clientName": client_name + } + ) as response: + result = await response.json() + self.auth_token = result["authToken"] + self.shared_secret = bytes.fromhex(result["sharedSecret"]) + return result + + def encrypt_payload(self, data): + aesgcm = AESGCM(self.shared_secret) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, data.encode(), None) + + return { + "ciphertext": ciphertext.hex(), + "iv": nonce.hex() + } + + def decrypt_payload(self, encrypted_hex, iv_hex): + aesgcm = AESGCM(self.shared_secret) + ciphertext = bytes.fromhex(encrypted_hex) + nonce = bytes.fromhex(iv_hex) + + plaintext = aesgcm.decrypt(nonce, ciphertext, None) + return plaintext.decode() + + def generate_nonce(self): + return secrets.token_hex(16) + + async def request(self, method, params=None): + if params is None: + params = {} + + payload = { + "jsonrpc": "2.0", + "method": method, + "id": 1, + "params": params, + "seq": self.sequence_number + 1, + "nonce": self.generate_nonce() + } + self.sequence_number += 1 + + encrypted = self.encrypt_payload(json.dumps(payload)) + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/api/v0.1", + json={ + "encrypted": encrypted["ciphertext"], + "iv": encrypted["iv"], + "authToken": self.auth_token + } + ) as response: + encrypted_response = await response.json() + decrypted = self.decrypt_payload( + encrypted_response["encrypted"], + encrypted_response["iv"] + ) + return json.loads(decrypted) + +# Usage +client = ZaparooSecureClient("192.168.1.100") +await client.pair("MyPythonApp", "verification-code") +result = await client.request("system.version") +``` + +## Error Handling + +### Common Error Responses + +| Error | HTTP Status | Description | +| ---------------------- | ----------- | -------------------------------- | +| Invalid auth token | 401 | Token not found or expired | +| Invalid request format | 400 | Malformed JSON or missing fields | +| Decryption failed | 400 | Invalid encryption or wrong key | +| Invalid sequence | 400 | Replay attack detected | +| Method not allowed | 405 | Wrong HTTP method | + +### Example Error Response + +```json +{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32600, + "message": "Invalid Request" + } +} +``` + +### Client-Side Error Handling + +```javascript +try { + const result = await client.request("system.version"); + console.log(result); +} catch (error) { + if (error.response?.status === 401) { + console.log("Authentication failed - re-pair required"); + await client.pair(clientName, newVerifier); + } else if (error.response?.status === 400) { + console.log("Request error:", error.response.data); + } else { + console.log("Network error:", error.message); + } +} +``` + +## Security Best Practices + +### For Client Developers + +1. **Secure Key Storage**: Store shared secrets securely (keychain, encrypted storage) +2. **Sequence Management**: Persist sequence numbers across app restarts +3. **Nonce Uniqueness**: Generate cryptographically random nonces +4. **Error Handling**: Don't expose sensitive information in logs +5. **Network Security**: Use TLS for additional transport security +6. **Token Rotation**: Implement re-pairing for long-running applications + +### Example Secure Storage (Node.js) + +```javascript +import keytar from "keytar"; + +class SecureStorage { + static async storeCredentials( + clientId, + authToken, + sharedSecret, + sequenceNumber + ) { + await keytar.setPassword("zaparoo", `${clientId}-auth`, authToken); + await keytar.setPassword("zaparoo", `${clientId}-secret`, sharedSecret); + await keytar.setPassword( + "zaparoo", + `${clientId}-seq`, + sequenceNumber.toString() + ); + } + + static async loadCredentials(clientId) { + const authToken = await keytar.getPassword("zaparoo", `${clientId}-auth`); + const sharedSecret = await keytar.getPassword( + "zaparoo", + `${clientId}-secret` + ); + const sequenceNumber = parseInt( + (await keytar.getPassword("zaparoo", `${clientId}-seq`)) || "0" + ); + + return { authToken, sharedSecret, sequenceNumber }; + } +} +``` + +### Sequence Number Management + +```javascript +class SequenceManager { + constructor(clientId) { + this.clientId = clientId; + this.sequenceNumber = this.loadSequenceNumber(); + } + + getNext() { + this.sequenceNumber++; + this.saveSequenceNumber(); + return this.sequenceNumber; + } + + loadSequenceNumber() { + // Load from persistent storage + return parseInt( + localStorage.getItem(`zaparoo-seq-${this.clientId}`) || "0" + ); + } + + saveSequenceNumber() { + localStorage.setItem( + `zaparoo-seq-${this.clientId}`, + this.sequenceNumber.toString() + ); + } +} +``` + +## API Endpoints Reference + +### Available Endpoints + +| Endpoint | Authentication | Description | +| ------------------------- | ----------------- | ------------------------ | +| `POST /api/pair/initiate` | None | Start pairing process | +| `POST /api/pair/complete` | None | Complete pairing process | +| `POST /api/v0.1` | Required (remote) | JSON-RPC API | +| `WS /api/v0.1` | Required (remote) | WebSocket API | + +### Supported Methods + +Once authenticated, all standard Zaparoo Core JSON-RPC methods are available: + +- `system.version` - Get system version +- `system.heartbeat` - Health check +- `search.games` - Search game database +- `launch.game` - Launch a game +- `media.scan` - Trigger media scan +- And more... (see main API documentation) + +--- + +**Note**: This secure authentication layer is only required for remote connections. Localhost connections (127.0.0.1, ::1) continue to work without authentication for development and local tool access. diff --git a/go.mod b/go.mod index 9a19ffedb..4d1dedcb7 100755 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( go.bug.st/serial v1.6.4 go.etcd.io/bbolt v1.4.0 golang.design/x/clipboard v0.7.0 + golang.org/x/crypto v0.38.0 golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 golang.org/x/sys v0.35.0 golang.org/x/text v0.26.0 @@ -54,6 +55,8 @@ require ( github.com/geoffgarside/ber v1.1.0 // indirect github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/gorilla/csrf v1.7.3 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/jcmturner/aescts/v2 v2.0.0 // indirect github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect @@ -73,8 +76,8 @@ require ( github.com/sethvargo/go-retry v0.3.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af // indirect + github.com/unrolled/secure v1.17.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp/shiny v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/image v0.20.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect diff --git a/go.sum b/go.sum index c2abad4e7..bbe7e1339 100644 --- a/go.sum +++ b/go.sum @@ -59,8 +59,12 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -143,6 +147,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af h1:6yITBqGTE2lEeTPG04SN9W+iWHCRyHqlVYILiSXziwk= github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af/go.mod h1:4F09kP5F+am0jAwlQLddpoMDM+iewkxxt6nxUQ5nq5o= +github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= +github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.bug.st/serial v1.6.4 h1:7FmqNPgVp3pu2Jz5PoPtbZ9jJO5gnEnZIvnI1lzve8A= go.bug.st/serial v1.6.4/go.mod h1:nofMJxTeNVny/m6+KaafC6vJGj3miwQZ6vW4BZUGJPI= diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 817b144b0..7518c0e02 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -36,7 +36,6 @@ import ( "time" "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" - "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" "github.com/rs/zerolog/log" ) @@ -118,8 +117,8 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Validate auth token and get client - client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) + // First, validate auth token and get client for initial validation + initialClient, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) if err != nil { tokenStr := "empty" if len(encReq.AuthToken) >= 8 { @@ -127,13 +126,17 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { } else if encReq.AuthToken != "" { tokenStr = encReq.AuthToken } - log.Error().Err(err).Str("token", tokenStr).Msg("invalid auth token") + log.Warn().Err(err). + Str("token", tokenStr). + Str("remote_addr", r.RemoteAddr). + Str("user_agent", r.Header.Get("User-Agent")). + Msg("SECURITY: invalid auth token - potential attack") http.Error(w, "Invalid auth token", http.StatusUnauthorized) return } - // Decrypt payload - decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, client.SharedSecret) + // Decrypt payload with initial client data + decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, initialClient.SharedSecret) if err != nil { log.Error().Err(err).Msg("failed to decrypt payload") http.Error(w, "Decryption failed", http.StatusBadRequest) @@ -148,46 +151,44 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Acquire client lock to prevent race conditions - // between validation and database update - unlockClient := LockClient(client.ClientID) + // CRITICAL: Acquire client lock BEFORE any sequence/nonce validation + // to prevent race conditions between concurrent requests + unlockClient := LockClient(initialClient.ClientID) defer unlockClient() // Re-fetch client state under lock to get latest sequence/nonce state - freshClient, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) + client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) if err != nil { log.Error().Err(err).Msg("failed to re-fetch client under lock") http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Validate sequence number and nonce with fresh client state - if !ValidateSequenceAndNonce(freshClient, payload.Seq, payload.Nonce) { + // Validate sequence number and nonce with locked client state + if !ValidateSequenceAndNonce(client, payload.Seq, payload.Nonce) { log.Warn(). - Str("client_id", freshClient.ClientID). + Str("client_id", client.ClientID). Uint64("seq", payload.Seq). Str("nonce", payload.Nonce). - Msg("invalid sequence or replay attack detected") + Str("remote_addr", r.RemoteAddr). + Str("user_agent", r.Header.Get("User-Agent")). + Msg("SECURITY: replay attack detected") http.Error(w, "Invalid sequence or replay detected", http.StatusBadRequest) return } // Update client state (sequence and nonce cache) - updatedClient := *freshClient - updateClientSequenceAndNonce(&updatedClient, payload.Seq, payload.Nonce) + updateClientSequenceAndNonce(client, payload.Seq, payload.Nonce) // Save to database (still under lock) if updateErr := db.UserDB.UpdateClientSequence( - updatedClient.ClientID, updatedClient.CurrentSeq, updatedClient.SeqWindow, updatedClient.NonceCache, + client.ClientID, client.CurrentSeq, client.SeqWindow, client.NonceCache, ); updateErr != nil { log.Error().Err(updateErr).Msg("failed to update client state") http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Update the client pointer for context (use fresh client with updates) - client = &updatedClient - // Replace request body with decrypted JSON-RPC payload originalPayload := map[string]any{ "jsonrpc": payload.JSONRPC, @@ -213,11 +214,12 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { ctx := context.WithValue(r.Context(), clientKey("client"), client) r = r.WithContext(ctx) - log.Debug(). + log.Info(). Str("client_id", client.ClientID). Str("method", payload.Method). Uint64("seq", payload.Seq). - Msg("authenticated request processed") + Str("remote_addr", r.RemoteAddr). + Msg("SECURITY: authenticated request processed") next.ServeHTTP(w, r) }) @@ -350,51 +352,6 @@ func updateClientSequenceAndNonce(client *database.Client, seq uint64, nonce str } } -func UpdateClientState(userDB *userdb.UserDB, client *database.Client, seq uint64, nonce string) error { - // Update nonce cache (keep last NonceCacheSize nonces) - client.NonceCache = append(client.NonceCache, nonce) - if len(client.NonceCache) > NonceCacheSize { - client.NonceCache = client.NonceCache[1:] // Remove oldest - } - - // Update sequence window - if seq > client.CurrentSeq { - // New highest sequence - shift window - shift := seq - client.CurrentSeq - if shift >= SequenceWindow { - // Clear entire window - client.SeqWindow = make([]byte, 8) - } else { - // Shift window right - for range shift { - shiftWindowRight(client.SeqWindow) - } - } - client.CurrentSeq = seq - - // Mark current sequence as processed (position 0 in window) - client.SeqWindow[0] |= 1 - } else { - // Mark this sequence as processed in the window - diff := client.CurrentSeq - seq - windowPos := diff % SequenceWindow - bytePos := windowPos / 8 - bitPos := windowPos % 8 - - if bytePos < uint64(len(client.SeqWindow)) { - client.SeqWindow[bytePos] |= (1 << bitPos) - } - } - - // Update database - if err := userDB.UpdateClientSequence( - client.ClientID, client.CurrentSeq, client.SeqWindow, client.NonceCache, - ); err != nil { - return fmt.Errorf("failed to update client sequence: %w", err) - } - return nil -} - // getClientMutex retrieves or creates a mutex for the given client ID func (cm *ClientMutexManager) getClientMutex(clientID string) *clientMutex { // Try to load existing mutex @@ -441,14 +398,20 @@ func (cm *ClientMutexManager) cleanup() { }) } -// startCleanupRoutine starts a background goroutine to periodically clean up unused mutexes -func (cm *ClientMutexManager) startCleanupRoutine() { +// StartCleanupRoutine starts a background goroutine to periodically clean up unused mutexes +func (cm *ClientMutexManager) StartCleanupRoutine(ctx context.Context) { go func() { ticker := time.NewTicker(MutexCleanupInterval) defer ticker.Stop() - for range ticker.C { - cm.cleanup() + for { + select { + case <-ctx.Done(): + log.Debug().Msg("client mutex cleanup routine stopped") + return + case <-ticker.C: + cm.cleanup() + } } }() } @@ -463,9 +426,9 @@ func LockClient(clientID string) func() { return globalClientMutexManager.lockClient(clientID) } -func init() { - // Start cleanup routine for client mutexes - globalClientMutexManager.startCleanupRoutine() +// StartGlobalMutexCleanup starts the global mutex cleanup routine +func StartGlobalMutexCleanup(ctx context.Context) { + globalClientMutexManager.StartCleanupRoutine(ctx) } func shiftWindowRight(window []byte) { diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go index 35f10b867..a621afb86 100644 --- a/pkg/api/middleware/auth_test.go +++ b/pkg/api/middleware/auth_test.go @@ -138,7 +138,7 @@ func TestAuthMiddleware_EncryptedRequest(t *testing.T) { userDB := helpers.NewMockUserDBI() userDB.On("GetClientByAuthToken", "test-auth-token").Return(testDevice, nil) - userDB.On("UpdateDeviceSequence", "test-device-id", uint64(1), + userDB.On("UpdateClientSequence", "test-device-id", uint64(1), mock.AnythingOfType("[]uint8"), mock.AnythingOfType("[]string")).Return(nil) db := &database.Database{UserDB: userDB} @@ -264,6 +264,36 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { expectedResult: true, description: "sequence within sliding window should pass", }, + { + name: "sequence at window boundary", + currentSeq: 64, + seqWindow: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + nonceCache: []string{}, + newSeq: 1, // Exactly at window boundary (64 behind) + newNonce: "nonce1", + expectedResult: false, + description: "sequence exactly at window boundary should be rejected", + }, + { + name: "sequence already processed", + currentSeq: 10, + seqWindow: []byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, // Bit 2 set (seq 8) + nonceCache: []string{}, + newSeq: 8, // Already processed + newNonce: "nonce8", + expectedResult: false, + description: "already processed sequence should be rejected", + }, + { + name: "large sequence jump", + currentSeq: 5, + seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + nonceCache: []string{}, + newSeq: 100, // Large jump forward + newNonce: "nonce100", + expectedResult: true, + description: "large sequence jump should be accepted", + }, } for _, tt := range tests { @@ -393,20 +423,20 @@ func TestAuthMiddleware_InvalidRequests(t *testing.T) { func TestGetClientFromContext(t *testing.T) { t.Parallel() - // Test with device in context - device := &database.Client{ClientID: "test-device"} - ctx := context.WithValue(context.Background(), clientKey("device"), device) + // Test with client in context + client := &database.Client{ClientID: "test-device"} + ctx := context.WithValue(context.Background(), clientKey("client"), client) result := GetClientFromContext(ctx) - assert.Equal(t, device, result) + assert.Equal(t, client, result) - // Test with no device in context + // Test with no client in context emptyCtx := context.Background() result = GetClientFromContext(emptyCtx) assert.Nil(t, result) // Test with wrong type in context - badCtx := context.WithValue(context.Background(), clientKey("device"), "not-a-device") + badCtx := context.WithValue(context.Background(), clientKey("client"), "not-a-client") result = GetClientFromContext(badCtx) assert.Nil(t, result) } @@ -521,3 +551,108 @@ func TestClientMutexManager_ConcurrentAccess(t *testing.T) { mutex := value.(*clientMutex) assert.Equal(t, deviceID, mutex.clientID) } + +// TestShiftWindowRight verifies the bit manipulation for the sliding window +func TestShiftWindowRight(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "shift zeros", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "shift ones", + input: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + expected: []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + }, + { + name: "shift single bit", + input: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + expected: []byte{0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "shift pattern", + input: []byte{0xAA, 0x55, 0xAA, 0x55, 0xAA, 0x55, 0xAA, 0x55}, + expected: []byte{0x55, 0x2A, 0xD5, 0x2A, 0xD5, 0x2A, 0xD5, 0x2A}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Make a copy to avoid modifying the test input + window := make([]byte, len(tt.input)) + copy(window, tt.input) + + shiftWindowRight(window) + + assert.Equal(t, tt.expected, window, "bit shift should match expected result") + }) + } +} + +// TestUpdateClientSequenceAndNonce verifies the sequence window updates correctly +func TestUpdateClientSequenceAndNonce(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + initialSeq uint64 + initialWindow []byte + newSeq uint64 + nonce string + expectedSeq uint64 + checkBitSet bool + bitPosition uint64 + }{ + { + name: "increment sequence", + initialSeq: 5, + initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + newSeq: 6, + nonce: "nonce6", + expectedSeq: 6, + checkBitSet: true, + bitPosition: 0, // Latest sequence should be at position 0 + }, + { + name: "large jump forward", + initialSeq: 5, + initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + newSeq: 100, + nonce: "nonce100", + expectedSeq: 100, + checkBitSet: true, + bitPosition: 0, // Latest sequence should be at position 0 after window clear + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + client := &database.Client{ + ClientID: "test-device", + CurrentSeq: tt.initialSeq, + SeqWindow: make([]byte, len(tt.initialWindow)), + NonceCache: []string{}, + } + copy(client.SeqWindow, tt.initialWindow) + + updateClientSequenceAndNonce(client, tt.newSeq, tt.nonce) + + assert.Equal(t, tt.expectedSeq, client.CurrentSeq, "sequence should be updated") + assert.Contains(t, client.NonceCache, tt.nonce, "nonce should be added to cache") + + if tt.checkBitSet { + // Check that the bit at position 0 is set (latest sequence) + assert.NotZero(t, client.SeqWindow[0]&1, "bit 0 should be set for latest sequence") + } + }) + } +} diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go index ca4af2676..78697f983 100644 --- a/pkg/api/pairing.go +++ b/pkg/api/pairing.go @@ -20,18 +20,23 @@ package api import ( + "context" "crypto/rand" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "net/http" + "regexp" + "runtime" + "strings" "sync" "time" "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" "github.com/google/uuid" "github.com/rs/zerolog/log" + "golang.org/x/crypto/argon2" "golang.org/x/crypto/hkdf" ) @@ -78,9 +83,29 @@ var pairingManager = &PairingManager{ sessions: make(map[string]*PairingSession), } -func init() { - // Start cleanup routine - go pairingManager.cleanup() +var clientNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +func validateClientName(name string) error { + name = strings.TrimSpace(name) + + if len(name) == 0 { + return fmt.Errorf("client name cannot be empty") + } + + if len(name) > 100 { + return fmt.Errorf("client name too long (max 100 characters)") + } + + if !clientNameRegex.MatchString(name) { + return fmt.Errorf("client name contains invalid characters (only letters, numbers, underscore, and dash allowed)") + } + + return nil +} + +// StartPairingCleanup starts the pairing session cleanup routine +func StartPairingCleanup(ctx context.Context) { + go pairingManager.cleanup(ctx) } func (pm *PairingManager) createSession() (*PairingSession, error) { @@ -141,19 +166,32 @@ func (pm *PairingManager) consumeSession(token string) (*PairingSession, bool) { return session, true } -func (pm *PairingManager) cleanup() { +func (pm *PairingManager) cleanup(ctx context.Context) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() - for range ticker.C { - pm.mu.Lock() - now := time.Now() - for token, session := range pm.sessions { - if now.Sub(session.CreatedAt) > PairingTokenExpiry { - delete(pm.sessions, token) + for { + select { + case <-ctx.Done(): + log.Debug().Msg("pairing manager cleanup routine stopped") + return + case <-ticker.C: + pm.mu.Lock() + now := time.Now() + for token, session := range pm.sessions { + if now.Sub(session.CreatedAt) > PairingTokenExpiry { + // Zero sensitive data before deletion + for i := range session.Challenge { + session.Challenge[i] = 0 + } + for i := range session.Salt { + session.Salt[i] = 0 + } + delete(pm.sessions, token) + } } + pm.mu.Unlock() } - pm.mu.Unlock() } } @@ -205,6 +243,17 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { return } + // Validate client name + if err := validateClientName(req.ClientName); err != nil { + log.Warn(). + Str("client_name", req.ClientName). + Str("remote_addr", r.RemoteAddr). + Err(err). + Msg("SECURITY: invalid client name provided") + http.Error(w, fmt.Sprintf("Invalid client name: %s", err.Error()), http.StatusBadRequest) + return + } + // Get and consume pairing session session, exists := pairingManager.consumeSession(req.PairingToken) if !exists { @@ -212,10 +261,36 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { return } - // Derive shared secret using HKDF (challenge + verifier) - combinedSecret := make([]byte, len(session.Challenge)+len(req.Verifier)) + // First, use Argon2id for strong key derivation from verifier + // Using security parameters recommended by OWASP (2023) + verifierKey := argon2.IDKey( + []byte(req.Verifier), // password + session.Salt, // salt (32 bytes) + 3, // time parameter (iterations) + 64*1024, // memory parameter (64 MB) + 4, // parallelism parameter + 32, // key length (32 bytes) + ) + // Ensure verifierKey is zeroed after use + defer func() { + for i := range verifierKey { + verifierKey[i] = 0 + } + runtime.KeepAlive(verifierKey) // Prevent compiler optimization + }() + + // Then combine with challenge using HKDF for domain separation + combinedSecret := make([]byte, len(session.Challenge)+len(verifierKey)) copy(combinedSecret, session.Challenge) - copy(combinedSecret[len(session.Challenge):], req.Verifier) + copy(combinedSecret[len(session.Challenge):], verifierKey) + // Ensure combinedSecret is zeroed after use + defer func() { + for i := range combinedSecret { + combinedSecret[i] = 0 + } + runtime.KeepAlive(combinedSecret) // Prevent compiler optimization + }() + sharedSecret := make([]byte, 32) // 256 bits for AES-256 // Construct context-specific info string for domain separation @@ -239,6 +314,14 @@ func handlePairingComplete(db *database.Database) http.HandlerFunc { return } + // Zero sensitive session data before removing + for i := range session.Challenge { + session.Challenge[i] = 0 + } + for i := range session.Salt { + session.Salt[i] = 0 + } + // Remove session from manager pairingManager.mu.Lock() delete(pairingManager.sessions, req.PairingToken) diff --git a/pkg/api/server.go b/pkg/api/server.go index 8917c296a..7b1951f6c 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -22,6 +22,7 @@ package api import ( "bytes" "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -43,7 +44,6 @@ import ( "github.com/ZaparooProject/zaparoo-core/v2/pkg/assets" "github.com/ZaparooProject/zaparoo-core/v2/pkg/config" "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" - "github.com/ZaparooProject/zaparoo-core/v2/pkg/database/userdb" "github.com/ZaparooProject/zaparoo-core/v2/pkg/helpers" "github.com/ZaparooProject/zaparoo-core/v2/pkg/platforms" "github.com/ZaparooProject/zaparoo-core/v2/pkg/service/state" @@ -52,8 +52,10 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/google/uuid" + "github.com/gorilla/csrf" "github.com/olahol/melody" "github.com/rs/zerolog/log" + "github.com/unrolled/secure" ) var allowedOrigins = []string{ @@ -607,12 +609,27 @@ func handleWSMessage( // First message must be authentication err := handleWSAuthentication(session, msg, db) if err != nil { - log.Error().Err(err).Msg("WebSocket authentication failed") + log.Warn().Err(err). + Str("remote_addr", session.Request.RemoteAddr). + Str("user_agent", session.Request.Header.Get("User-Agent")). + Msg("SECURITY: WebSocket authentication failed") _ = session.Close() } return } + // Check session timeout + if isSessionExpired(session) { + log.Warn(). + Str("remote_addr", session.Request.RemoteAddr). + Msg("SECURITY: WebSocket session expired - closing connection") + _ = session.Close() + return + } + + // Update last activity time + session.Set("auth_time", time.Now()) + // Decrypt message for authenticated remote connection clientObj, ok := client.(*database.Client) if !ok { @@ -680,6 +697,26 @@ type WSAuthMessage struct { AuthToken string `json:"authToken"` } +type CSRFTokenResponse struct { + CSRFToken string `json:"csrfToken"` +} + +const sessionTimeout = 30 * time.Minute + +func isSessionExpired(session *melody.Session) bool { + authTime, exists := session.Get("auth_time") + if !exists { + return true // No auth time means expired + } + + timestamp, ok := authTime.(time.Time) + if !ok { + return true // Invalid timestamp means expired + } + + return time.Since(timestamp) > sessionTimeout +} + func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Database) error { var authMsg WSAuthMessage if err := json.Unmarshal(msg, &authMsg); err != nil { @@ -696,8 +733,9 @@ func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Da return fmt.Errorf("invalid auth token: %w", err) } - // Store client in session + // Store client and auth timestamp in session session.Set("client", client) + session.Set("auth_time", time.Now()) // Send authentication success response authResponse := map[string]any{ @@ -711,7 +749,10 @@ func handleWSAuthentication(session *melody.Session, msg []byte, db *database.Da return fmt.Errorf("failed to send auth response: %w", err) } - log.Debug().Str("client_id", client.ClientID).Msg("WebSocket authenticated") + log.Info(). + Str("client_id", client.ClientID). + Str("remote_addr", session.Request.RemoteAddr). + Msg("SECURITY: WebSocket authenticated successfully") return nil } @@ -750,12 +791,48 @@ func handleWSDecryption(_ *melody.Session, msg []byte, client *database.Client, return nil, errors.New("invalid sequence or replay detected") } - // Update client state under lock - userDB, ok := db.UserDB.(*userdb.UserDB) - if !ok { - return nil, errors.New("failed to cast UserDB to concrete type") + // Update client state with sequence and nonce + // Update nonce cache (keep last NonceCacheSize nonces) + freshClient.NonceCache = append(freshClient.NonceCache, payload.Nonce) + if len(freshClient.NonceCache) > 100 { // NonceCacheSize + freshClient.NonceCache = freshClient.NonceCache[1:] // Remove oldest + } + + // Update sequence window + if payload.Seq > freshClient.CurrentSeq { + // New highest sequence - shift window + shift := payload.Seq - freshClient.CurrentSeq + if shift >= 64 { // SequenceWindow + // Clear entire window + freshClient.SeqWindow = make([]byte, 8) + } else { + // Shift window right + for range shift { + for i := len(freshClient.SeqWindow) - 1; i > 0; i-- { + freshClient.SeqWindow[i] = (freshClient.SeqWindow[i] << 1) | (freshClient.SeqWindow[i-1] >> 7) + } + freshClient.SeqWindow[0] <<= 1 + } + } + freshClient.CurrentSeq = payload.Seq + + // Mark current sequence as processed (position 0 in window) + freshClient.SeqWindow[0] |= 1 + } else { + // Mark this sequence as processed in the window + diff := freshClient.CurrentSeq - payload.Seq + windowPos := diff % 64 // SequenceWindow + bytePos := windowPos / 8 + bitPos := windowPos % 8 + + if bytePos < uint64(len(freshClient.SeqWindow)) { + freshClient.SeqWindow[bytePos] |= (1 << bitPos) + } } - if updateErr := apimiddleware.UpdateClientState(userDB, freshClient, payload.Seq, payload.Nonce); updateErr != nil { + + if updateErr := db.UserDB.UpdateClientSequence( + freshClient.ClientID, freshClient.CurrentSeq, freshClient.SeqWindow, freshClient.NonceCache, + ); updateErr != nil { return nil, fmt.Errorf("failed to update client state: %w", updateErr) } @@ -918,14 +995,39 @@ func Start( rateLimiter.StartCleanup(st.GetContext()) r.Use(apimiddleware.HTTPRateLimitMiddleware(rateLimiter)) + + // Start global mutex cleanup routine with proper context + apimiddleware.StartGlobalMutexCleanup(st.GetContext()) + + // Start pairing session cleanup routine with proper context + StartPairingCleanup(st.GetContext()) + + // Generate random CSRF key once at startup for security + csrfKey := make([]byte, 32) + if _, err := rand.Read(csrfKey); err != nil { + log.Fatal().Err(err).Msg("failed to generate CSRF key") + } + + // Security headers middleware - protect against common attacks + secureMiddleware := secure.New(secure.Options{ + FrameDeny: true, // X-Frame-Options: DENY + ContentTypeNosniff: true, // X-Content-Type-Options: nosniff + BrowserXssFilter: true, // X-XSS-Protection: 1; mode=block + ContentSecurityPolicy: "default-src 'self'; connect-src 'self' ws: wss:; " + + "img-src 'self' data:; style-src 'self' 'unsafe-inline'", + ReferrerPolicy: "strict-origin-when-cross-origin", + }) + r.Use(secureMiddleware.Handler) + r.Use(middleware.Recoverer) r.Use(middleware.NoCache) r.Use(middleware.Timeout(config.APIRequestTimeout)) r.Use(cors.Handler(cors.Options{ - AllowedOrigins: dynamicAllowedOrigins, - AllowedMethods: []string{"GET", "POST", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Content-Type"}, - ExposedHeaders: []string{}, + AllowedOrigins: dynamicAllowedOrigins, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Content-Type", "X-CSRF-Token"}, + ExposedHeaders: []string{"X-CSRF-Token"}, + AllowCredentials: true, // Required for CSRF protection })) if strings.HasSuffix(config.AppVersion, "-dev") { @@ -943,9 +1045,39 @@ func Start( } go broadcastNotifications(st, session, notifications) - // Pairing endpoints (no authentication required) - r.Post("/api/pair/initiate", handlePairingInitiate(db)) - r.Post("/api/pair/complete", handlePairingComplete(db)) + // Pairing endpoints with CSRF protection and enhanced rate limiting + r.Route("/api/pair", func(r chi.Router) { + + // CSRF protection for pairing endpoints + r.Use(csrf.Protect( + csrfKey, + csrf.Secure(false), // Set to true when using HTTPS + csrf.HttpOnly(true), + csrf.SameSite(csrf.SameSiteStrictMode), + )) + + // Enhanced rate limiting for pairing (more restrictive) + pairingLimiter := apimiddleware.NewIPRateLimiter() + pairingLimiter.StartCleanup(st.GetContext()) + r.Use(apimiddleware.HTTPRateLimitMiddleware(pairingLimiter)) + + r.Post("/initiate", handlePairingInitiate(db)) + r.Post("/complete", handlePairingComplete(db)) + + // CSRF token endpoint for clients + r.Get("/csrf-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + token := csrf.Token(r) + + response := CSRFTokenResponse{ + CSRFToken: token, + } + + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Error().Err(err).Msg("Failed to write CSRF token response") + } + }) + }) // Protected API routes with authentication middleware r.Route("/api", func(r chi.Router) { diff --git a/pkg/config/config.go b/pkg/config/config.go index 8bd6ab319..b1baca1f2 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -359,6 +359,12 @@ func (c *Instance) AllowedOrigins() []string { return c.vals.Service.AllowedOrigins } +func (c *Instance) DeviceID() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.vals.Service.DeviceID +} + func (c *Instance) IsExecuteAllowed(s string) bool { c.mu.RLock() defer c.mu.RUnlock() diff --git a/pkg/database/userdb/clients.go b/pkg/database/userdb/clients.go index 362a6fd7c..7d3472ec4 100644 --- a/pkg/database/userdb/clients.go +++ b/pkg/database/userdb/clients.go @@ -22,6 +22,7 @@ package userdb import ( "context" "crypto/sha256" + "crypto/subtle" "encoding/hex" "encoding/json" "errors" @@ -87,43 +88,39 @@ func (db *UserDB) GetClientByAuthToken(authToken string) (*database.Client, erro return nil, ErrNullSQL } - authTokenHash := hashAuthToken(authToken) - - query := ` - SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, - seq_window, nonce_cache, created_at, last_seen - FROM clients - WHERE auth_token_hash = ? - ` + return db.getClientByAuthTokenConstantTime(authToken) +} - var client database.Client - var nonceCacheJSON string - var createdAt, lastSeen int64 +func (db *UserDB) getClientByAuthTokenConstantTime(authToken string) (*database.Client, error) { + targetHash := hashAuthToken(authToken) + targetHashBytes := []byte(targetHash) - err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan( - &client.ClientID, - &client.ClientName, - &client.AuthTokenHash, - &client.SharedSecret, - &client.CurrentSeq, - &client.SeqWindow, - &nonceCacheJSON, - &createdAt, - &lastSeen, - ) + // Get all clients to prevent timing attacks through database query optimization + clients, err := db.GetAllClients() if err != nil { - return nil, fmt.Errorf("client not found: %w", err) + return nil, fmt.Errorf("failed to get clients: %w", err) } - client.CreatedAt = time.Unix(createdAt, 0) - client.LastSeen = time.Unix(lastSeen, 0) + var foundClient *database.Client + // Use constant-time comparison for all clients + for i := range clients { + client := &clients[i] + clientHashBytes := []byte(client.AuthTokenHash) + + // Ensure both hashes are same length to prevent timing attacks + if len(targetHashBytes) == len(clientHashBytes) { + if subtle.ConstantTimeCompare(targetHashBytes, clientHashBytes) == 1 { + foundClient = client + // Don't break - continue checking all clients for constant time + } + } + } - err = json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache) - if err != nil { - client.NonceCache = make([]string, 0) // Fallback to empty cache + if foundClient == nil { + return nil, fmt.Errorf("client not found") } - return &client, nil + return foundClient, nil } func (db *UserDB) GetClientByID(clientID string) (*database.Client, error) { @@ -278,3 +275,4 @@ func hashAuthToken(authToken string) string { hash := sha256.Sum256([]byte(authToken)) return hex.EncodeToString(hash[:]) } + From 2b6b7722e68d89f4968f44c8f37eea0b89462676 Mon Sep 17 00:00:00 2001 From: Callan Barrett Date: Sat, 13 Sep 2025 10:28:54 +0800 Subject: [PATCH 4/5] lints --- pkg/api/middleware/auth.go | 32 ++++++++++++----- pkg/api/middleware/auth_test.go | 63 ++++++++++++++++++--------------- pkg/api/pairing.go | 9 ++--- pkg/api/server.go | 5 ++- pkg/database/userdb/clients.go | 5 ++- 5 files changed, 67 insertions(+), 47 deletions(-) diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 7518c0e02..7c98f41be 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -54,12 +54,16 @@ type ClientMutexManager struct { mutexes sync.Map // map[string]*clientMutex } -type clientMutex struct { +// ClientMutex represents a per-client mutex for thread-safe operations +type ClientMutex struct { lastUsed time.Time clientID string mu sync.Mutex } +// Legacy alias for backward compatibility +type clientMutex = ClientMutex + var globalClientMutexManager = &ClientMutexManager{} type EncryptedRequest struct { @@ -353,23 +357,31 @@ func updateClientSequenceAndNonce(client *database.Client, seq uint64, nonce str } // getClientMutex retrieves or creates a mutex for the given client ID -func (cm *ClientMutexManager) getClientMutex(clientID string) *clientMutex { +func (cm *ClientMutexManager) getClientMutex(clientID string) *ClientMutex { // Try to load existing mutex if value, exists := cm.mutexes.Load(clientID); exists { - mutex := value.(*clientMutex) + mutex, ok := value.(*clientMutex) + if !ok { + log.Error().Str("client_id", clientID).Msg("invalid mutex type in cache") + return nil + } mutex.lastUsed = time.Now() return mutex } // Create new mutex - newMutex := &clientMutex{ + newMutex := &ClientMutex{ lastUsed: time.Now(), clientID: clientID, } // Store and return the mutex (LoadOrStore handles race conditions) actual, _ := cm.mutexes.LoadOrStore(clientID, newMutex) - actualMutex := actual.(*clientMutex) + actualMutex, ok := actual.(*clientMutex) + if !ok { + log.Error().Str("client_id", clientID).Msg("invalid mutex type after LoadOrStore") + return nil + } actualMutex.lastUsed = time.Now() return actualMutex } @@ -388,8 +400,12 @@ func (cm *ClientMutexManager) lockClient(clientID string) func() { // cleanup removes unused mutexes to prevent memory leaks func (cm *ClientMutexManager) cleanup() { now := time.Now() - cm.mutexes.Range(func(key, value interface{}) bool { - mutex := value.(*clientMutex) + cm.mutexes.Range(func(key, value any) bool { + mutex, ok := value.(*clientMutex) + if !ok { + log.Error().Interface("key", key).Msg("invalid mutex type in cleanup") + return true + } if now.Sub(mutex.lastUsed) > MutexMaxIdle { cm.mutexes.Delete(key) log.Debug().Str("client_id", mutex.clientID).Msg("cleaned up unused client mutex") @@ -417,7 +433,7 @@ func (cm *ClientMutexManager) StartCleanupRoutine(ctx context.Context) { } // GetClientMutex is a convenience function to get a mutex for a client -func GetClientMutex(clientID string) *clientMutex { +func GetClientMutex(clientID string) *ClientMutex { return globalClientMutexManager.getClientMutex(clientID) } diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go index a621afb86..837ecc128 100644 --- a/pkg/api/middleware/auth_test.go +++ b/pkg/api/middleware/auth_test.go @@ -104,6 +104,7 @@ func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() // Send regular JSON without proper auth fields - should fail at auth token lookup req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) req.RemoteAddr = tt.remoteAddr @@ -408,6 +409,7 @@ func TestAuthMiddleware_InvalidRequests(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(tt.body))) req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth @@ -534,8 +536,10 @@ func TestClientMutexManager_ConcurrentAccess(t *testing.T) { defer func() { done <- struct{}{} }() mutex := dm.getClientMutex(deviceID) - require.NotNil(t, mutex) - assert.Equal(t, deviceID, mutex.clientID) + assert.NotNil(t, mutex) + if mutex != nil { + assert.Equal(t, deviceID, mutex.clientID) + } }() } @@ -548,14 +552,15 @@ func TestClientMutexManager_ConcurrentAccess(t *testing.T) { value, exists := dm.mutexes.Load(deviceID) assert.True(t, exists, "Mutex should exist") - mutex := value.(*clientMutex) + mutex, ok := value.(*clientMutex) + require.True(t, ok) assert.Equal(t, deviceID, mutex.clientID) } // TestShiftWindowRight verifies the bit manipulation for the sliding window func TestShiftWindowRight(t *testing.T) { t.Parallel() - + tests := []struct { name string input []byte @@ -589,9 +594,9 @@ func TestShiftWindowRight(t *testing.T) { // Make a copy to avoid modifying the test input window := make([]byte, len(tt.input)) copy(window, tt.input) - + shiftWindowRight(window) - + assert.Equal(t, tt.expected, window, "bit shift should match expected result") }) } @@ -600,36 +605,36 @@ func TestShiftWindowRight(t *testing.T) { // TestUpdateClientSequenceAndNonce verifies the sequence window updates correctly func TestUpdateClientSequenceAndNonce(t *testing.T) { t.Parallel() - + tests := []struct { - name string - initialSeq uint64 + name string + nonce string initialWindow []byte - newSeq uint64 - nonce string - expectedSeq uint64 - checkBitSet bool - bitPosition uint64 + initialSeq uint64 + newSeq uint64 + expectedSeq uint64 + bitPosition uint64 + checkBitSet bool }{ { - name: "increment sequence", - initialSeq: 5, + name: "increment sequence", + initialSeq: 5, initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - newSeq: 6, - nonce: "nonce6", - expectedSeq: 6, - checkBitSet: true, - bitPosition: 0, // Latest sequence should be at position 0 + newSeq: 6, + nonce: "nonce6", + expectedSeq: 6, + checkBitSet: true, + bitPosition: 0, // Latest sequence should be at position 0 }, { - name: "large jump forward", - initialSeq: 5, + name: "large jump forward", + initialSeq: 5, initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - newSeq: 100, - nonce: "nonce100", - expectedSeq: 100, - checkBitSet: true, - bitPosition: 0, // Latest sequence should be at position 0 after window clear + newSeq: 100, + nonce: "nonce100", + expectedSeq: 100, + checkBitSet: true, + bitPosition: 0, // Latest sequence should be at position 0 after window clear }, } @@ -648,7 +653,7 @@ func TestUpdateClientSequenceAndNonce(t *testing.T) { assert.Equal(t, tt.expectedSeq, client.CurrentSeq, "sequence should be updated") assert.Contains(t, client.NonceCache, tt.nonce, "nonce should be added to cache") - + if tt.checkBitSet { // Check that the bit at position 0 is set (latest sequence) assert.NotZero(t, client.SeqWindow[0]&1, "bit 0 should be set for latest sequence") diff --git a/pkg/api/pairing.go b/pkg/api/pairing.go index 78697f983..335c604aa 100644 --- a/pkg/api/pairing.go +++ b/pkg/api/pairing.go @@ -25,6 +25,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "net/http" "regexp" @@ -88,16 +89,16 @@ var clientNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) func validateClientName(name string) error { name = strings.TrimSpace(name) - if len(name) == 0 { - return fmt.Errorf("client name cannot be empty") + if name == "" { + return errors.New("client name cannot be empty") } if len(name) > 100 { - return fmt.Errorf("client name too long (max 100 characters)") + return errors.New("client name too long (max 100 characters)") } if !clientNameRegex.MatchString(name) { - return fmt.Errorf("client name contains invalid characters (only letters, numbers, underscore, and dash allowed)") + return errors.New("client name contains invalid characters (only letters, numbers, underscore, and dash)") } return nil diff --git a/pkg/api/server.go b/pkg/api/server.go index 7b1951f6c..05caffa08 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -708,12 +708,12 @@ func isSessionExpired(session *melody.Session) bool { if !exists { return true // No auth time means expired } - + timestamp, ok := authTime.(time.Time) if !ok { return true // Invalid timestamp means expired } - + return time.Since(timestamp) > sessionTimeout } @@ -1047,7 +1047,6 @@ func Start( // Pairing endpoints with CSRF protection and enhanced rate limiting r.Route("/api/pair", func(r chi.Router) { - // CSRF protection for pairing endpoints r.Use(csrf.Protect( csrfKey, diff --git a/pkg/database/userdb/clients.go b/pkg/database/userdb/clients.go index 7d3472ec4..0ac1fc174 100644 --- a/pkg/database/userdb/clients.go +++ b/pkg/database/userdb/clients.go @@ -106,7 +106,7 @@ func (db *UserDB) getClientByAuthTokenConstantTime(authToken string) (*database. for i := range clients { client := &clients[i] clientHashBytes := []byte(client.AuthTokenHash) - + // Ensure both hashes are same length to prevent timing attacks if len(targetHashBytes) == len(clientHashBytes) { if subtle.ConstantTimeCompare(targetHashBytes, clientHashBytes) == 1 { @@ -117,7 +117,7 @@ func (db *UserDB) getClientByAuthTokenConstantTime(authToken string) (*database. } if foundClient == nil { - return nil, fmt.Errorf("client not found") + return nil, errors.New("client not found") } return foundClient, nil @@ -275,4 +275,3 @@ func hashAuthToken(authToken string) string { hash := sha256.Sum256([]byte(authToken)) return hex.EncodeToString(hash[:]) } - From 12a5ad2bc36fd65913237cc2802cde1ffb71a93d Mon Sep 17 00:00:00 2001 From: Callan Barrett Date: Sat, 13 Sep 2025 12:50:45 +0800 Subject: [PATCH 5/5] new wireguard based replay guard --- pkg/api/middleware/auth.go | 121 +++------- pkg/api/middleware/auth_test.go | 277 +++++++++++------------ pkg/api/middleware/replay_filter.go | 93 ++++++++ pkg/api/middleware/replay_filter_test.go | 128 +++++++++++ pkg/api/middleware/replay_protection.go | 113 +++++++++ pkg/api/server.go | 197 ++++++++-------- pkg/database/userdb/clients.go | 55 ++--- 7 files changed, 609 insertions(+), 375 deletions(-) create mode 100644 pkg/api/middleware/replay_filter.go create mode 100644 pkg/api/middleware/replay_filter_test.go create mode 100644 pkg/api/middleware/replay_protection.go diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 7c98f41be..bcafa3b39 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -31,7 +31,6 @@ import ( "io" "net" "net/http" - "slices" "sync" "time" @@ -40,8 +39,6 @@ import ( ) const ( - SequenceWindow = 64 // Size of sliding window for sequence numbers - NonceCacheSize = 100 // Maximum number of cached nonces MutexCleanupInterval = 10 * time.Minute // Cleanup unused mutexes every 10 minutes MutexMaxIdle = 30 * time.Minute // Remove mutexes unused for 30 minutes ) @@ -139,8 +136,21 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Decrypt payload with initial client data - decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, initialClient.SharedSecret) + // Acquire client lock BEFORE any decryption or validation + // to prevent race conditions between concurrent requests + unlockClient := LockClient(initialClient.ClientID) + defer unlockClient() + + // Re-fetch client state under lock to get latest sequence/nonce state + client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) + if err != nil { + log.Error().Err(err).Msg("failed to re-fetch client under lock") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Decrypt payload with locked client data + decryptedPayload, err := DecryptPayload(encReq.Encrypted, encReq.IV, client.SharedSecret) if err != nil { log.Error().Err(err).Msg("failed to decrypt payload") http.Error(w, "Decryption failed", http.StatusBadRequest) @@ -155,21 +165,11 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // CRITICAL: Acquire client lock BEFORE any sequence/nonce validation - // to prevent race conditions between concurrent requests - unlockClient := LockClient(initialClient.ClientID) - defer unlockClient() - - // Re-fetch client state under lock to get latest sequence/nonce state - client, err := db.UserDB.GetClientByAuthToken(encReq.AuthToken) - if err != nil { - log.Error().Err(err).Msg("failed to re-fetch client under lock") - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } + // Create replay protector from client state + replayProtector := NewReplayProtector(client) // Validate sequence number and nonce with locked client state - if !ValidateSequenceAndNonce(client, payload.Seq, payload.Nonce) { + if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) { log.Warn(). Str("client_id", client.ClientID). Uint64("seq", payload.Seq). @@ -181,8 +181,14 @@ func AuthMiddleware(db *database.Database) func(http.Handler) http.Handler { return } - // Update client state (sequence and nonce cache) - updateClientSequenceAndNonce(client, payload.Seq, payload.Nonce) + // Update replay protector state + replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce) + + // Get updated state for database storage + currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase() + client.CurrentSeq = currentSeq + client.SeqWindow = seqWindow + client.NonceCache = nonceCache // Save to database (still under lock) if updateErr := db.UserDB.UpdateClientSequence( @@ -290,72 +296,6 @@ func EncryptPayload(data, key []byte) (encrypted, iv string, err error) { base64.StdEncoding.EncodeToString(ivBytes), nil } -func ValidateSequenceAndNonce(client *database.Client, seq uint64, nonce string) bool { - // Check if nonce was recently used (replay protection) - if slices.Contains(client.NonceCache, nonce) { - return false - } - - // Validate sequence number with sliding window - if seq <= client.CurrentSeq { - // Check if sequence is within acceptable window - diff := client.CurrentSeq - seq - if diff >= SequenceWindow { - return false // Too old - } - - // Check if this sequence was already processed (using seq_window bitmap) - windowPos := diff % SequenceWindow - bytePos := windowPos / 8 - bitPos := windowPos % 8 - - if bytePos < uint64(len(client.SeqWindow)) { - if (client.SeqWindow[bytePos] & (1 << bitPos)) != 0 { - return false // Already processed - } - } - } - - return true -} - -func updateClientSequenceAndNonce(client *database.Client, seq uint64, nonce string) { - // Update nonce cache (keep last NonceCacheSize nonces) - client.NonceCache = append(client.NonceCache, nonce) - if len(client.NonceCache) > NonceCacheSize { - client.NonceCache = client.NonceCache[1:] // Remove oldest - } - - // Update sequence window - if seq > client.CurrentSeq { - // New highest sequence - shift window - shift := seq - client.CurrentSeq - if shift >= SequenceWindow { - // Clear entire window - client.SeqWindow = make([]byte, 8) - } else { - // Shift window right - for range shift { - shiftWindowRight(client.SeqWindow) - } - } - client.CurrentSeq = seq - - // Mark current sequence as processed (position 0 in window) - client.SeqWindow[0] |= 1 - } else { - // Mark this sequence as processed in the window - diff := client.CurrentSeq - seq - windowPos := diff % SequenceWindow - bytePos := windowPos / 8 - bitPos := windowPos % 8 - - if bytePos < uint64(len(client.SeqWindow)) { - client.SeqWindow[bytePos] |= (1 << bitPos) - } - } -} - // getClientMutex retrieves or creates a mutex for the given client ID func (cm *ClientMutexManager) getClientMutex(clientID string) *ClientMutex { // Try to load existing mutex @@ -447,15 +387,6 @@ func StartGlobalMutexCleanup(ctx context.Context) { globalClientMutexManager.StartCleanupRoutine(ctx) } -func shiftWindowRight(window []byte) { - carry := byte(0) - for i := len(window) - 1; i >= 0; i-- { - newCarry := (window[i] & 1) << 7 - window[i] = (window[i] >> 1) | carry - carry = newCarry - } -} - func IsAuthenticatedConnection(r *http.Request) bool { return !isLocalhost(r.RemoteAddr) } diff --git a/pkg/api/middleware/auth_test.go b/pkg/api/middleware/auth_test.go index 837ecc128..96325d943 100644 --- a/pkg/api/middleware/auth_test.go +++ b/pkg/api/middleware/auth_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/require" ) +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation func TestAuthMiddleware_LocalhostBypass(t *testing.T) { t.Parallel() // Setup @@ -64,7 +65,6 @@ func TestAuthMiddleware_LocalhostBypass(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) req.RemoteAddr = tt.remoteAddr @@ -77,6 +77,7 @@ func TestAuthMiddleware_LocalhostBypass(t *testing.T) { } } +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { t.Parallel() // Setup @@ -104,7 +105,6 @@ func TestAuthMiddleware_RemoteRequiresAuth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() // Send regular JSON without proper auth fields - should fail at auth token lookup req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(`{"test": "data"}`))) req.RemoteAddr = tt.remoteAddr @@ -203,7 +203,8 @@ func TestAuthMiddleware_EncryptedRequest(t *testing.T) { userDB.AssertExpectations(t) } -func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation +func TestReplayProtector_ReplayProtection(t *testing.T) { t.Parallel() tests := []struct { name string @@ -218,7 +219,7 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { { name: "first message", currentSeq: 0, - seqWindow: make([]byte, 8), + seqWindow: make([]byte, 8+128*8), // Ring buffer: 8 bytes + 128 blocks * 8 bytes nonceCache: []string{}, newSeq: 1, newNonce: "nonce1", @@ -228,7 +229,7 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { { name: "sequence increment", currentSeq: 5, - seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + seqWindow: make([]byte, 8+128*8), nonceCache: []string{"old-nonce"}, newSeq: 6, newNonce: "nonce6", @@ -238,7 +239,7 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { { name: "duplicate nonce", currentSeq: 5, - seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + seqWindow: make([]byte, 8+128*8), nonceCache: []string{"duplicate-nonce"}, newSeq: 6, newNonce: "duplicate-nonce", @@ -246,52 +247,32 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { description: "duplicate nonce should be rejected", }, { - name: "old sequence out of window", - currentSeq: 100, - seqWindow: make([]byte, 8), + name: "old sequence far out of window", + currentSeq: 50000, + seqWindow: make([]byte, 8+128*8), nonceCache: []string{}, - newSeq: 10, // More than 64 behind - newNonce: "nonce10", + newSeq: 100, // More than 8000+ behind (outside WireGuard window) + newNonce: "nonce100", expectedResult: false, description: "sequence too far behind should be rejected", }, { - name: "sequence within window", - currentSeq: 10, - seqWindow: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + name: "sequence within large window", + currentSeq: 1000, + seqWindow: make([]byte, 8+128*8), nonceCache: []string{}, - newSeq: 8, // 2 behind, within window - newNonce: "nonce8", + newSeq: 950, // Within large window + newNonce: "nonce950", expectedResult: true, description: "sequence within sliding window should pass", }, { - name: "sequence at window boundary", - currentSeq: 64, - seqWindow: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - nonceCache: []string{}, - newSeq: 1, // Exactly at window boundary (64 behind) - newNonce: "nonce1", - expectedResult: false, - description: "sequence exactly at window boundary should be rejected", - }, - { - name: "sequence already processed", - currentSeq: 10, - seqWindow: []byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, // Bit 2 set (seq 8) - nonceCache: []string{}, - newSeq: 8, // Already processed - newNonce: "nonce8", - expectedResult: false, - description: "already processed sequence should be rejected", - }, - { - name: "large sequence jump", + name: "large sequence jump forward", currentSeq: 5, - seqWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + seqWindow: make([]byte, 8+128*8), nonceCache: []string{}, - newSeq: 100, // Large jump forward - newNonce: "nonce100", + newSeq: 1000, // Large jump forward + newNonce: "nonce1000", expectedResult: true, description: "large sequence jump should be accepted", }, @@ -299,15 +280,15 @@ func TestValidateSequenceAndNonce_ReplayProtection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - device := &database.Client{ + client := &database.Client{ ClientID: "test-device", CurrentSeq: tt.currentSeq, SeqWindow: tt.seqWindow, NonceCache: tt.nonceCache, } - result := ValidateSequenceAndNonce(device, tt.newSeq, tt.newNonce) + replayProtector := NewReplayProtector(client) + result := replayProtector.ValidateSequenceAndNonce(tt.newSeq, tt.newNonce) assert.Equal(t, tt.expectedResult, result, tt.description) }) } @@ -346,6 +327,7 @@ func TestEncryptDecryptPayload_WrongKey(t *testing.T) { assert.Contains(t, err.Error(), "decryption failed") } +//nolint:paralleltest,tparallel // Security tests require deterministic execution order func TestIsLocalhost(t *testing.T) { t.Parallel() tests := []struct { @@ -365,18 +347,18 @@ func TestIsLocalhost(t *testing.T) { for _, tt := range tests { t.Run(tt.addr, func(t *testing.T) { - t.Parallel() result := isLocalhost(tt.addr) assert.Equal(t, tt.expected, result) }) } } +//nolint:paralleltest,tparallel // Security tests require deterministic mock validation func TestAuthMiddleware_InvalidRequests(t *testing.T) { t.Parallel() userDB := helpers.NewMockUserDBI() // Mock empty auth token lookup - expect it to be called once for the missing auth token test - userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError) + userDB.On("GetClientByAuthToken", "").Return((*database.Client)(nil), assert.AnError).Once() db := &database.Database{UserDB: userDB} @@ -409,7 +391,6 @@ func TestAuthMiddleware_InvalidRequests(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader([]byte(tt.body))) req.RemoteAddr = "192.168.1.100:5000" // Remote address to trigger auth @@ -487,6 +468,105 @@ func TestAuthMiddleware_ConcurrentRequests(t *testing.T) { assert.Equal(t, int32(0), atomic.LoadInt32(&lockAcquired), "All locks should be released") } +// TestAuthMiddleware_ConcurrentAuthentication tests actual concurrent authentication +// requests to ensure the race condition fix prevents replay attacks +func TestAuthMiddleware_ConcurrentAuthentication(t *testing.T) { + t.Parallel() + + // Setup + userDB := helpers.NewMockUserDBI() + db := &database.Database{UserDB: userDB} + + // Create a test client + testClient := &database.Client{ + ClientID: "test-client-concurrent", + ClientName: "Test Client", + AuthTokenHash: "test-hash", + SharedSecret: []byte("test-secret-32-bytes-long-key!!!"), + CurrentSeq: 5, + SeqWindow: make([]byte, 8), + NonceCache: []string{}, + CreatedAt: time.Now(), + LastSeen: time.Now(), + } + + // Mock database calls + userDB.On("GetClientByAuthToken", "test-token").Return(testClient, nil) + userDB.On("UpdateClientSequence", + testClient.ClientID, + mock.AnythingOfType("uint64"), + mock.AnythingOfType("[]uint8"), + mock.AnythingOfType("[]string")).Return(nil) + + // Create test handler that tracks successful authentications + var successCount int32 + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&successCount, 1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + }) + + middleware := AuthMiddleware(db) + wrappedHandler := middleware(handler) + + // Create encrypted request with same sequence number (should cause replay detection) + payload := map[string]any{ + "jsonrpc": "2.0", + "method": "test.method", + "id": 1, + "nonce": "test-nonce-concurrent", + "seq": uint64(6), // Same sequence for all requests + } + payloadJSON, _ := json.Marshal(payload) + + encrypted, iv, err := EncryptPayload(payloadJSON, testClient.SharedSecret) + require.NoError(t, err) + + encRequest := map[string]string{ + "encrypted": encrypted, + "iv": iv, + "authToken": "test-token", + } + requestBody, _ := json.Marshal(encRequest) + + // Run concurrent requests + const numRequests = 10 + done := make(chan int, numRequests) + + for i := range numRequests { + go func(_ int) { + req := httptest.NewRequest(http.MethodPost, "/api/test", bytes.NewReader(requestBody)) + req.RemoteAddr = "192.168.1.100:12345" // Remote address to trigger auth + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + done <- w.Code + }(i) + } + + // Collect results + statusCodes := make([]int, 0, numRequests) + for range numRequests { + statusCodes = append(statusCodes, <-done) + } + + // Only ONE request should succeed (200), others should fail with replay detection (400) + successStatusCount := 0 + for _, code := range statusCodes { + if code == http.StatusOK { + successStatusCount++ + } else { + assert.Equal(t, http.StatusBadRequest, code, "Failed requests should return 400 for replay detection") + } + } + + assert.Equal(t, 1, successStatusCount, "Only one concurrent request should succeed") + assert.Equal(t, int32(1), atomic.LoadInt32(&successCount), "Handler should only be called once") + + userDB.AssertExpectations(t) +} + // TestClientMutexManager_Cleanup verifies mutex cleanup works correctly func TestClientMutexManager_Cleanup(t *testing.T) { t.Parallel() @@ -556,108 +636,3 @@ func TestClientMutexManager_ConcurrentAccess(t *testing.T) { require.True(t, ok) assert.Equal(t, deviceID, mutex.clientID) } - -// TestShiftWindowRight verifies the bit manipulation for the sliding window -func TestShiftWindowRight(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "shift zeros", - input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - name: "shift ones", - input: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, - expected: []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, - }, - { - name: "shift single bit", - input: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - expected: []byte{0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - }, - { - name: "shift pattern", - input: []byte{0xAA, 0x55, 0xAA, 0x55, 0xAA, 0x55, 0xAA, 0x55}, - expected: []byte{0x55, 0x2A, 0xD5, 0x2A, 0xD5, 0x2A, 0xD5, 0x2A}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - // Make a copy to avoid modifying the test input - window := make([]byte, len(tt.input)) - copy(window, tt.input) - - shiftWindowRight(window) - - assert.Equal(t, tt.expected, window, "bit shift should match expected result") - }) - } -} - -// TestUpdateClientSequenceAndNonce verifies the sequence window updates correctly -func TestUpdateClientSequenceAndNonce(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - nonce string - initialWindow []byte - initialSeq uint64 - newSeq uint64 - expectedSeq uint64 - bitPosition uint64 - checkBitSet bool - }{ - { - name: "increment sequence", - initialSeq: 5, - initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - newSeq: 6, - nonce: "nonce6", - expectedSeq: 6, - checkBitSet: true, - bitPosition: 0, // Latest sequence should be at position 0 - }, - { - name: "large jump forward", - initialSeq: 5, - initialWindow: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - newSeq: 100, - nonce: "nonce100", - expectedSeq: 100, - checkBitSet: true, - bitPosition: 0, // Latest sequence should be at position 0 after window clear - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - client := &database.Client{ - ClientID: "test-device", - CurrentSeq: tt.initialSeq, - SeqWindow: make([]byte, len(tt.initialWindow)), - NonceCache: []string{}, - } - copy(client.SeqWindow, tt.initialWindow) - - updateClientSequenceAndNonce(client, tt.newSeq, tt.nonce) - - assert.Equal(t, tt.expectedSeq, client.CurrentSeq, "sequence should be updated") - assert.Contains(t, client.NonceCache, tt.nonce, "nonce should be added to cache") - - if tt.checkBitSet { - // Check that the bit at position 0 is set (latest sequence) - assert.NotZero(t, client.SeqWindow[0]&1, "bit 0 should be set for latest sequence") - } - }) - } -} diff --git a/pkg/api/middleware/replay_filter.go b/pkg/api/middleware/replay_filter.go new file mode 100644 index 000000000..7bf923dec --- /dev/null +++ b/pkg/api/middleware/replay_filter.go @@ -0,0 +1,93 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +/* +ATTRIBUTION NOTICE: + +This file contains code derived from WireGuard's anti-replay filter implementation. + +Original work: Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. +Original license: MIT License +Source: https://github.com/WireGuard/wireguard-go/blob/master/replay/replay.go + +The derived portions remain under their original MIT license as required. +The combined work is distributed under GPL-3.0-or-later as indicated in the file header. +*/ + +type block uint64 + +const ( + blockBitLog = 6 + blockBits = 1 << blockBitLog + ringBlocks = 1 << 7 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 +) + +// Filter implements an anti-replay mechanism as specified in RFC 6479. +// It rejects replayed messages by checking if message counter value is within +// a sliding window of previously received messages. +type Filter struct { + last uint64 + ring [ringBlocks]block +} + +// Reset resets the filter to an empty state. +func (f *Filter) Reset() { + f.last = 0 + for i := range f.ring { + f.ring[i] = 0 + } +} + +// ValidateCounter checks if a given counter should be accepted. +// It automatically rejects counters >= the specified limit. +func (f *Filter) ValidateCounter(counter, limit uint64) bool { + if counter >= limit { + return false + } + + indexBlock := counter >> blockBitLog + + if counter > f.last { + // New highest counter - shift window + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks + } + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { + // Too old + return false + } + + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + updated := old | 1<. + +package middleware + +import ( + "testing" +) + +const RejectAfterMessages = 1<<64 - 1<<13 - 1 + +func TestReplayFilter(t *testing.T) { + t.Parallel() + var filter Filter + const tLim = windowSize + 1 + testNumber := 0 + + testFunc := func(n uint64, expected bool) { + testNumber++ + if filter.ValidateCounter(n, RejectAfterMessages) != expected { + t.Fatal("Test", testNumber, "failed", n, expected) + } + } + + filter.Reset() + + testFunc(0, true) + testFunc(1, true) + testFunc(1, false) + testFunc(9, true) + testFunc(8, true) + testFunc(7, true) + testFunc(7, false) + testFunc(tLim, true) + testFunc(tLim-1, true) + testFunc(tLim-1, false) + testFunc(tLim-2, true) + testFunc(2, true) + testFunc(2, false) + testFunc(tLim+16, true) + testFunc(3, false) + testFunc(tLim+16, false) + testFunc(tLim*4, true) + testFunc(tLim*4-(tLim-1), true) + testFunc(10, false) + testFunc(tLim*4-tLim, false) + testFunc(tLim*4-(tLim+1), false) + testFunc(tLim*4-(tLim-2), true) + testFunc(tLim*4+1-tLim, false) + testFunc(0, false) + testFunc(RejectAfterMessages, false) + testFunc(RejectAfterMessages-1, true) + testFunc(RejectAfterMessages, false) + testFunc(RejectAfterMessages-1, false) + testFunc(RejectAfterMessages-2, true) + testFunc(RejectAfterMessages+1, false) + testFunc(RejectAfterMessages+2, false) + testFunc(RejectAfterMessages-2, false) + testFunc(RejectAfterMessages-3, true) + testFunc(0, false) + + t.Log("Bulk test 1") + filter.Reset() + testNumber = 0 + for i := uint64(1); i <= windowSize; i++ { + testFunc(i, true) + } + testFunc(0, true) + testFunc(0, false) + + t.Log("Bulk test 2") + filter.Reset() + testNumber = 0 + for i := uint64(2); i <= windowSize+1; i++ { + testFunc(i, true) + } + testFunc(1, true) + testFunc(0, false) + + t.Log("Bulk test 3") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize + 1); i > 0; i-- { + testFunc(i, true) + } + + t.Log("Bulk test 4") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize + 2); i > 1; i-- { + testFunc(i, true) + } + testFunc(0, false) + + t.Log("Bulk test 5") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize); i > 0; i-- { + testFunc(i, true) + } + testFunc(windowSize+1, true) + testFunc(0, false) + + t.Log("Bulk test 6") + filter.Reset() + testNumber = 0 + for i := uint64(windowSize); i > 0; i-- { + testFunc(i, true) + } + testFunc(0, true) + testFunc(windowSize+1, true) +} diff --git a/pkg/api/middleware/replay_protection.go b/pkg/api/middleware/replay_protection.go new file mode 100644 index 000000000..d28146472 --- /dev/null +++ b/pkg/api/middleware/replay_protection.go @@ -0,0 +1,113 @@ +// Zaparoo Core +// Copyright (c) 2025 The Zaparoo Project Contributors. +// SPDX-License-Identifier: GPL-3.0-or-later +// +// This file is part of Zaparoo Core. +// +// Zaparoo Core is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Zaparoo Core is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Zaparoo Core. If not, see . + +package middleware + +import ( + "encoding/binary" + "slices" + + "github.com/ZaparooProject/zaparoo-core/v2/pkg/database" +) + +const ( + MaxSequenceNumber = 1<<64 - 1<<13 - 1 // Same as WireGuard's RejectAfterMessages + NonceCacheSize = 100 // Maximum number of cached nonces +) + +// ReplayProtector combines WireGuard's sequence number validation with nonce replay protection +type ReplayProtector struct { + nonceCache []string + filter Filter + lastSeq uint64 +} + +// NewReplayProtector creates a new replay protector from client state +func NewReplayProtector(client *database.Client) *ReplayProtector { + rp := &ReplayProtector{ + nonceCache: make([]string, len(client.NonceCache)), + lastSeq: client.CurrentSeq, + } + copy(rp.nonceCache, client.NonceCache) + + // Initialize filter from stored ring buffer state + if len(client.SeqWindow) >= 8+len(rp.filter.ring)*8 { + storedLastSeq := binary.LittleEndian.Uint64(client.SeqWindow[0:8]) + if storedLastSeq > 0 { + // Load stored state: first 8 bytes = last sequence, then 8 bytes per ring block + rp.filter.last = storedLastSeq + for i := range rp.filter.ring { + offset := 8 + i*8 + rp.filter.ring[i] = block(binary.LittleEndian.Uint64(client.SeqWindow[offset : offset+8])) + } + } else { + // Buffer exists but uninitialized - treat as new client + rp.filter.Reset() + if client.CurrentSeq > 0 { + rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber) + } + } + } else { + // Buffer too small - initialize fresh + rp.filter.Reset() + if client.CurrentSeq > 0 { + rp.filter.ValidateCounter(client.CurrentSeq, MaxSequenceNumber) + } + } + + return rp +} + +// ValidateSequenceAndNonce validates both sequence number and nonce for replay protection +func (rp *ReplayProtector) ValidateSequenceAndNonce(seq uint64, nonce string) bool { + // Check nonce first (replay protection) + if slices.Contains(rp.nonceCache, nonce) { + return false + } + + // Then validate sequence number using WireGuard's algorithm + return rp.filter.ValidateCounter(seq, MaxSequenceNumber) +} + +// UpdateSequenceAndNonce updates the replay protector state after successful validation +func (rp *ReplayProtector) UpdateSequenceAndNonce(seq uint64, nonce string) { + // Update nonce cache (keep last NonceCacheSize nonces) + rp.nonceCache = append(rp.nonceCache, nonce) + if len(rp.nonceCache) > NonceCacheSize { + rp.nonceCache = rp.nonceCache[1:] // Remove oldest + } + + // Sequence is already updated by ValidateCounter if it was accepted + rp.lastSeq = max(rp.lastSeq, seq) +} + +// GetStateForDatabase returns the state that should be stored in the database +func (rp *ReplayProtector) GetStateForDatabase() (currentSeq uint64, seqWindow []byte, nonceCache []string) { + // Serialize ring buffer state + // Format: [8 bytes last seq][8 bytes per ring block] + seqWindow = make([]byte, 8+len(rp.filter.ring)*8) + binary.LittleEndian.PutUint64(seqWindow[0:8], rp.filter.last) + + for i, block := range rp.filter.ring[:] { + offset := 8 + i*8 + binary.LittleEndian.PutUint64(seqWindow[offset:offset+8], uint64(block)) + } + + return rp.filter.last, seqWindow, rp.nonceCache +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 05caffa08..86007649d 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -570,6 +570,77 @@ func processRequestObject( // handleWSMessage parses all incoming WS requests, identifies what type of // JSON-RPC object they may be and forwards them to the appropriate function // to handle that type of message. +// handleWSPing handles ping messages for heartbeat operation +func handleWSPing(session *melody.Session, msg []byte) bool { + if bytes.Equal(msg, []byte("ping")) { + err := session.Write([]byte("pong")) + if err != nil { + log.Error().Err(err).Msg("sending pong") + } + return true + } + return false +} + +// handleWSAuthentication handles authentication for remote WebSocket connections +func handleWSRemoteAuth(session *melody.Session, msg []byte, db *database.Database) ([]byte, error) { + client, authenticated := session.Get("client") + if !authenticated { + // First message must be authentication + err := handleWSAuthentication(session, msg, db) + if err != nil { + log.Warn().Err(err). + Str("remote_addr", session.Request.RemoteAddr). + Str("user_agent", session.Request.Header.Get("User-Agent")). + Msg("SECURITY: WebSocket authentication failed") + _ = session.Close() + } + return nil, err + } + + // Check session timeout + if isSessionExpired(session) { + log.Warn(). + Str("remote_addr", session.Request.RemoteAddr). + Msg("SECURITY: WebSocket session expired - closing connection") + _ = session.Close() + return nil, errors.New("session expired") + } + + // Update last activity time + session.Set("auth_time", time.Now()) + + // Decrypt message for authenticated remote connection + clientObj, ok := client.(*database.Client) + if !ok { + log.Error().Msg("invalid client type in session") + return nil, errors.New("invalid client type") + } + + decryptedMsg, err := handleWSDecryption(session, msg, clientObj, db) + if err != nil { + log.Error().Err(err).Msg("WebSocket decryption failed") + return nil, err + } + + return decryptedMsg, nil +} + +// sendWSResponseByType sends response based on connection type (local/remote) +func sendWSResponseByType(session *melody.Session, isLocal bool, id uuid.UUID, resp any) error { + if !isLocal { + if client, authenticated := session.Get("client"); authenticated { + clientObj, ok := client.(*database.Client) + if !ok { + log.Error().Msg("invalid client type in session") + return errors.New("invalid client type") + } + return sendWSResponseEncrypted(session, id, resp, clientObj) + } + } + return sendWSResponse(session, id, resp) +} + func handleWSMessage( methodMap *MethodMap, platform platforms.Platform, @@ -589,69 +660,33 @@ func handleWSMessage( } }() - // ping command for heartbeat operation - if bytes.Equal(msg, []byte("ping")) { - err := session.Write([]byte("pong")) - if err != nil { - log.Error().Err(err).Msg("sending pong") - } + // Handle ping/pong heartbeat + if handleWSPing(session, msg) { return } + // Determine if connection is local rawIP := strings.SplitN(session.Request.RemoteAddr, ":", 2) clientIP := net.ParseIP(rawIP[0]) isLocal := clientIP.IsLoopback() - // Handle authentication for remote connections + // Handle authentication and decryption for remote connections if !isLocal { - client, authenticated := session.Get("client") - if !authenticated { - // First message must be authentication - err := handleWSAuthentication(session, msg, db) - if err != nil { - log.Warn().Err(err). - Str("remote_addr", session.Request.RemoteAddr). - Str("user_agent", session.Request.Header.Get("User-Agent")). - Msg("SECURITY: WebSocket authentication failed") - _ = session.Close() - } - return - } - - // Check session timeout - if isSessionExpired(session) { - log.Warn(). - Str("remote_addr", session.Request.RemoteAddr). - Msg("SECURITY: WebSocket session expired - closing connection") - _ = session.Close() - return - } - - // Update last activity time - session.Set("auth_time", time.Now()) - - // Decrypt message for authenticated remote connection - clientObj, ok := client.(*database.Client) - if !ok { - log.Error().Msg("invalid client type in session") - err := sendWSError(session, uuid.Nil, JSONRPCErrorInternalError) - if err != nil { - log.Error().Err(err).Msg("failed to send WebSocket error") - } - return - } - decryptedMsg, err := handleWSDecryption(session, msg, clientObj, db) + decryptedMsg, err := handleWSRemoteAuth(session, msg, db) if err != nil { - log.Error().Err(err).Msg("WebSocket decryption failed") err := sendWSError(session, uuid.Nil, JSONRPCErrorInvalidRequest) if err != nil { - log.Error().Err(err).Msg("error sending decryption error response") + log.Error().Err(err).Msg("error sending auth error response") } return } + if decryptedMsg == nil { + return // Authentication in progress + } msg = decryptedMsg } + // Process the request env := requests.RequestEnv{ Platform: platform, Config: cfg, @@ -668,24 +703,7 @@ func handleWSMessage( log.Error().Err(err).Msg("error sending error response") } } else { - // Encrypt response for remote authenticated connections - if !isLocal { - if client, authenticated := session.Get("client"); authenticated { - clientObj, ok := client.(*database.Client) - if !ok { - log.Error().Msg("invalid client type in session") - return - } - err := sendWSResponseEncrypted(session, id, resp, clientObj) - if err != nil { - log.Error().Err(err).Msg("error sending encrypted response") - } - return - } - } - - // Send unencrypted response for localhost - err := sendWSResponse(session, id, resp) + err := sendWSResponseByType(session, isLocal, id, resp) if err != nil { log.Error().Err(err).Msg("error sending response") } @@ -786,49 +804,22 @@ func handleWSDecryption(_ *melody.Session, msg []byte, client *database.Client, return nil, fmt.Errorf("failed to re-fetch client under lock: %w", err) } + // Create replay protector from fresh client state + replayProtector := apimiddleware.NewReplayProtector(freshClient) + // Validate sequence and nonce with fresh client state - if !apimiddleware.ValidateSequenceAndNonce(freshClient, payload.Seq, payload.Nonce) { + if !replayProtector.ValidateSequenceAndNonce(payload.Seq, payload.Nonce) { return nil, errors.New("invalid sequence or replay detected") } - // Update client state with sequence and nonce - // Update nonce cache (keep last NonceCacheSize nonces) - freshClient.NonceCache = append(freshClient.NonceCache, payload.Nonce) - if len(freshClient.NonceCache) > 100 { // NonceCacheSize - freshClient.NonceCache = freshClient.NonceCache[1:] // Remove oldest - } - - // Update sequence window - if payload.Seq > freshClient.CurrentSeq { - // New highest sequence - shift window - shift := payload.Seq - freshClient.CurrentSeq - if shift >= 64 { // SequenceWindow - // Clear entire window - freshClient.SeqWindow = make([]byte, 8) - } else { - // Shift window right - for range shift { - for i := len(freshClient.SeqWindow) - 1; i > 0; i-- { - freshClient.SeqWindow[i] = (freshClient.SeqWindow[i] << 1) | (freshClient.SeqWindow[i-1] >> 7) - } - freshClient.SeqWindow[0] <<= 1 - } - } - freshClient.CurrentSeq = payload.Seq + // Update replay protector state + replayProtector.UpdateSequenceAndNonce(payload.Seq, payload.Nonce) - // Mark current sequence as processed (position 0 in window) - freshClient.SeqWindow[0] |= 1 - } else { - // Mark this sequence as processed in the window - diff := freshClient.CurrentSeq - payload.Seq - windowPos := diff % 64 // SequenceWindow - bytePos := windowPos / 8 - bitPos := windowPos % 8 - - if bytePos < uint64(len(freshClient.SeqWindow)) { - freshClient.SeqWindow[bytePos] |= (1 << bitPos) - } - } + // Get updated state for database storage + currentSeq, seqWindow, nonceCache := replayProtector.GetStateForDatabase() + freshClient.CurrentSeq = currentSeq + freshClient.SeqWindow = seqWindow + freshClient.NonceCache = nonceCache if updateErr := db.UserDB.UpdateClientSequence( freshClient.ClientID, freshClient.CurrentSeq, freshClient.SeqWindow, freshClient.NonceCache, @@ -952,7 +943,7 @@ func handlePostRequest( } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusOK) _, err = w.Write(respBody) if err != nil { log.Error().Err(err).Msg("failed to write error response") diff --git a/pkg/database/userdb/clients.go b/pkg/database/userdb/clients.go index 0ac1fc174..a8c7ad50d 100644 --- a/pkg/database/userdb/clients.go +++ b/pkg/database/userdb/clients.go @@ -22,7 +22,6 @@ package userdb import ( "context" "crypto/sha256" - "crypto/subtle" "encoding/hex" "encoding/json" "errors" @@ -88,39 +87,43 @@ func (db *UserDB) GetClientByAuthToken(authToken string) (*database.Client, erro return nil, ErrNullSQL } - return db.getClientByAuthTokenConstantTime(authToken) -} + authTokenHash := hashAuthToken(authToken) + + query := ` + SELECT client_id, client_name, auth_token_hash, shared_secret, current_seq, + seq_window, nonce_cache, created_at, last_seen + FROM clients + WHERE auth_token_hash = ? + ` -func (db *UserDB) getClientByAuthTokenConstantTime(authToken string) (*database.Client, error) { - targetHash := hashAuthToken(authToken) - targetHashBytes := []byte(targetHash) + var client database.Client + var nonceCacheJSON string + var createdAt, lastSeen int64 - // Get all clients to prevent timing attacks through database query optimization - clients, err := db.GetAllClients() + err := db.sql.QueryRowContext(context.Background(), query, authTokenHash).Scan( + &client.ClientID, + &client.ClientName, + &client.AuthTokenHash, + &client.SharedSecret, + &client.CurrentSeq, + &client.SeqWindow, + &nonceCacheJSON, + &createdAt, + &lastSeen, + ) if err != nil { - return nil, fmt.Errorf("failed to get clients: %w", err) + // Use a generic error to avoid leaking information about whether the token exists + return nil, errors.New("invalid credentials") } - var foundClient *database.Client - // Use constant-time comparison for all clients - for i := range clients { - client := &clients[i] - clientHashBytes := []byte(client.AuthTokenHash) - - // Ensure both hashes are same length to prevent timing attacks - if len(targetHashBytes) == len(clientHashBytes) { - if subtle.ConstantTimeCompare(targetHashBytes, clientHashBytes) == 1 { - foundClient = client - // Don't break - continue checking all clients for constant time - } - } - } + client.CreatedAt = time.Unix(createdAt, 0) + client.LastSeen = time.Unix(lastSeen, 0) - if foundClient == nil { - return nil, errors.New("client not found") + if unmarshalErr := json.Unmarshal([]byte(nonceCacheJSON), &client.NonceCache); unmarshalErr != nil { + client.NonceCache = make([]string, 0) // Fallback to empty cache } - return foundClient, nil + return &client, nil } func (db *UserDB) GetClientByID(clientID string) (*database.Client, error) {