diff --git a/.gitignore b/.gitignore index 959b774..5c6836b 100644 --- a/.gitignore +++ b/.gitignore @@ -56,7 +56,9 @@ server/.pytest_tmp/ *.db *.sqlite *.sqlite3 +*.db-journal blackwire.db +server/_write_test.txt # ------------------------------ # C++ / CMake / Qt / MSVC diff --git a/V0.2_Update.md b/V0.2_Update.md new file mode 100644 index 0000000..447af6d --- /dev/null +++ b/V0.2_Update.md @@ -0,0 +1,131 @@ +# v0.2: Security/Privacy upgrades that won’t “horribly” hurt performance + +Below is the smallest set of changes that massively improves security while keeping the system fast and deployable. + +## 1) Add client message signatures (high impact, low cost) + +**Goal:** a compromised server can’t impersonate users to other servers (and ideally can’t impersonate locally either). + +* Each device has a long-term **signing keypair** (Ed25519). +* Every outbound message envelope includes: + + * `sender_device_pubkey` + * `signature` over a canonical form of `(conversation_id, sender, recipient, timestamp, client_message_id, ciphertext_hash, …)` +* Receiver verifies signature **before** accepting. + +**Performance:** trivial (Ed25519 signing/verifying is fast). +**Impact:** huge integrity/auth win immediately. + +--- + +## 2) Replace “sealed box per message” with a session key + ratchet (moderate cost) + +To call it “secure messaging”, you need at least forward secrecy. The standard answer is: + +### X3DH-style handshake + Double Ratchet + +* Initial key agreement uses prekeys (server stores public prekeys only) +* Session uses Double Ratchet for: + + * forward secrecy + * post-compromise recovery + * message ordering / replay defenses + +**Performance impact:** low to moderate (mostly constant-time crypto; the overhead is protocol complexity, not CPU). + +--- + +## 3) Multi-device with “kick devices” + +You already have a concept of `device_id`. Extend it: + +* Account holds list of devices: + + * `device_uid` (server-generated stable ID) + * `device_pub_sign_key` + * `device_pub_dh_key` (for sessions) + * status: active / revoked +* In client settings: + + * list devices + last seen + * “kick” = server marks revoked +* Protocol: + + * messages can target one or multiple devices + * ratchet sessions are per device-pair (common approach) + +**Performance:** depends on “encrypt-to-how-many-devices”. Usually manageable. +**Privacy:** server still sees devices exist; content remains E2EE. + +--- + +## 4) Fix federation TOFU: “open federation” but safer bootstrap + +You want open federation, so the upgrade is: + +### bind federation trust to onion identity + +If Tor is enabled: + +* Tie federation identity to onion service identity (or publish a signed binding) +* On first contact, the onion channel itself becomes the “trust root” (still not perfect against malicious onion takeover, but better than plain TOFU) + +**Performance:** none. +**Security:** prevents silent key swaps and reduces TOFU poisoning risk. + +--- + +## 5) Token/auth hardening (privacy + security without latency) + +You correctly flagged HS256 risk and replay risk. + +For v0.2: + +* Move access tokens to **asymmetric signing** (EdDSA/RS256) +* Bind tokens to device: + + * include `device_uid` in JWT claims + * require match at server on every request +* Short access token lifetime (e.g., 5–10 min) +* Refresh token rotation stays (good) +* Consider DPoP-style proof or per-request nonce if you want to harden replay further (optional) + +**Performance:** negligible. + +--- + +## 6) End-to-end delivery integrity (detect drops/reordering) + +You said you want privacy and trust minimization; but “server can drop silently” is a real UX/security problem. + +Add: + +* Per-conversation monotonic counters or hash-chains: + + * each message includes `prev_hash` (hash of previous accepted message) + * receiver detects gaps/tampering +* “Missing message” UI that doesn’t reveal plaintext, just indicates integrity issue + +**Performance:** tiny (hashing). + +--- + +## 7) Voice: secure without “a bunch of latency” + +Your current voice (PCM over WS, no E2EE) is the biggest privacy hole. + +The lowest-latency, real-world solution is: + +### Use WebRTC for media, keep your federation for signaling + +* Signaling: your existing WS/federation routes +* Media: WebRTC SRTP (low latency, jitter buffers, NAT traversal) +* For true E2EE beyond SRTP termination concerns: use WebRTC Insertable Streams (where available) or an application-layer frame encryption + +This is how modern low-latency secure voice/video is typically done. + +If you refuse WebRTC and keep WS streaming: + +* You can still E2EE frames with AEAD + sequence numbers, +* but you’ll end up re-implementing jitter buffering + congestion control badly. + So: **WebRTC is the performance-friendly choice.** diff --git a/V0.2a_Status.md b/V0.2a_Status.md new file mode 100644 index 0000000..28bde97 --- /dev/null +++ b/V0.2a_Status.md @@ -0,0 +1,82 @@ +# Blackwire v0.2a Status (as of 2026-02-23) + +This file tracks what has been implemented from the v0.2 milestones and what is deferred to v0.2b. + +## v0.2a Milestones + +### 1) Client message signatures + sender-key pinning +Status: Done + +- Added detached signing/verification support in client crypto service. +- `/api/v2/messages/send` now carries signed per-device envelopes. +- Client verifies sender signatures before accept/decrypt. +- TOFU pinning map is persisted by `sender_device_uid`; key change triggers hard integrity warning. + +### 2) True multi-device accounts + revoke (kick) +Status: Done + +- Server supports multi-device listing, active/revoked status, and revoke endpoint in `/api/v2`. +- Revoke flow invalidates device sessions and enforces immediate cutoff behavior. +- Client send path resolves all active recipient devices and mirrors to own active devices. +- Settings UI now shows device inventory and supports kick/revoke with confirmation. + +### 3) Federation trust bound to Tor onion identity +Status: Done (v0.2a scope) + +- v2 federation well-known includes binding metadata/signing key. +- Peer onboarding validates signing key against onion-derived identity binding (Tor mode path). +- Existing signed/nonce federation protections are preserved. + +### 4) Asymmetric, device-bound JWT sessions (bootstrap -> bind) +Status: Done + +- Added v2 bootstrap auth flow (`register/login` -> bootstrap token). +- Added device registration and bind-device exchange for device-bound token issuance. +- Added EdDSA token subsystem (`sub`, `did`, `type`, `iat`, `exp`, `jti`) with refresh rotation. +- v2 auth dependencies enforce active device ownership on protected routes. + +### 5) End-to-end delivery integrity chain + integrity UI signaling +Status: Done + +- v2 message events store sender chain fields (`sender_prev_hash`, `sender_chain_hash`). +- Client computes/signs canonical message material and validates chain continuity. +- Integrity mismatches (key change/signature/chain gap) are surfaced as warnings. +- UI wiring now includes integrity warning banner + settings integrity status area. + +### 6) Parallel `/api/v2` with `/api/v1` compatibility +Status: Done + +- Added full `/api/v2` router set and WS endpoint without removing `/api/v1`. +- Added v2 message tables and flow without destructive rewrite of v1 tables. +- v1 compatibility path remains in place. + +### 7) Tests for v2 + v1 regression retention +Status: Mostly done (server complete, client pending local execution) + +- Added server integration coverage for v2 security core behavior. +- Existing v1 tests are still passing in current run. +- Client tests were expanded (crypto + serialization updates), but full local execution is still blocked in this environment because `cmake` is unavailable. + +## What is Left for v0.2b + +### A) X3DH + Double Ratchet migration +Status: Not started (deferred by plan) + +Planned next: +- Prekey publishing/consumption model. +- Session establishment via X3DH. +- Double Ratchet per device-pair with replay/ordering handling. +- Migration path from sealed-box-only v0.2a envelopes. + +### B) WebRTC media migration +Status: Not started (deferred by plan) + +Planned next: +- Keep signaling on existing control plane. +- Move voice media off raw WS PCM to WebRTC media path. +- Add media-plane E2EE strategy compatible with federation design. + +## Notes + +- v0.2a core security milestones are implemented in code and API surface. +- Remaining work to fully close v0.2a is operational validation on a machine with CMake/Qt test toolchain available. diff --git a/V0.2b_Status.md b/V0.2b_Status.md new file mode 100644 index 0000000..2e1e587 --- /dev/null +++ b/V0.2b_Status.md @@ -0,0 +1,72 @@ +# Blackwire v0.2b Status + +## Completed in this implementation pass + +### v0.2b1 (X3DH/Double Ratchet rollout scaffolding + dual-stack contracts) +- Added ratchet/prekey schema migration: + - `server/migrations/versions/20260223_0006_v2_ratchet_core.py` + - New tables: `device_signed_prekeys`, `device_one_time_prekeys`, `ratchet_sessions`, `ratchet_skipped_keys` + - Extended `message_events` with indexed `encryption_mode` +- Added v2 prekey contracts and service: + - `POST /api/v2/keys/prekeys/upload` + - `GET /api/v2/users/resolve-prekeys` + - `GET /api/v2/federation/users/{username}/prekeys` +- Extended v2 device resolution contract: + - Per-device `supported_message_modes` in local and federation device responses +- Extended v2 messaging contract: + - `encryption_mode` accepted and persisted (`sealedbox_v0_2a`, `ratchet_v0_2b1`) + - Ratchet envelope fields supported with typed schema validation: `ratchet_header`, optional `ratchet_init` + - Policy enforcement: + - `BLACKWIRE_ENABLE_RATCHET_V2B1` + - `BLACKWIRE_RATCHET_REQUIRE_FOR_LOCAL` + - `BLACKWIRE_RATCHET_REQUIRE_FOR_FEDERATION` + - Mode-specific metrics counters emitted + - Additional ratchet metrics wired: + - `ratchet.session.established` + - `ratchet.prekey.opk.exhausted` +- Client wiring for dual-stack behavior: + - DTO/API support for prekey upload/resolve and `encryption_mode` + - Send path negotiates ratchet mode by capability + prekey availability, else sealed-box fallback + - Prekey upload on device setup/re-bind +- Full cryptographic implementation of X3DH + Double Ratchet state transitions: + - Real root/send/recv chain key evolution + - Skipped-key material encryption lifecycle + - Post-compromise recovery semantics +- End-to-end ratchet decrypt path replacing sealed-box fallback behavior + + +### v0.2b2 (WebRTC signaling migration scaffolding) +- Added config gating: + - `BLACKWIRE_ENABLE_WEBRTC_V2B2` + - `BLACKWIRE_WEBRTC_ICE_SERVERS_JSON` startup validation + - `BLACKWIRE_ENABLE_LEGACY_CALL_AUDIO_WS` +- Added v2 WS signaling events: + - `call.webrtc.offer` + - `call.webrtc.answer` + - `call.webrtc.ice` +- Added v2 federation relay endpoints: + - `POST /api/v2/federation/calls/webrtc-offer` + - `POST /api/v2/federation/calls/webrtc-answer` + - `POST /api/v2/federation/calls/webrtc-ice` +- Added call schema fields for forward compatibility: + - `call_schema_version`, `call_mode`, `max_participants` +- Added WebRTC metadata propagation in `call.accepted` payloads: + - `call_schema_version`, `call_mode`, `max_participants`, `ice_servers` +- Client WS and controller wiring: + - New DTO/events/methods for WebRTC signaling + - UI state updates for signaling progress (`reason` changes/diagnostics), including ICE metadata diagnostics +- Real `libwebrtc` media engine integration in Qt client +- Actual SDP/ICE generation, candidate gathering, and RTP media track transport +- TURN/STUN runtime integration and media-failure recovery logic +- Decommissioning WS PCM transport after stabilization window + + +## Tests added and executed +- Added integration tests: + - `server/tests/integration/test_v2_prekey_upload_and_resolve` (in `test_v2_security_core.py`) + - `server/tests/integration/test_v2_ratchet_send_rejected_when_feature_disabled` (in `test_v2_security_core.py`) + - `server/tests/integration/test_v2_webrtc_signaling.py` +- Executed and passing: + - `server/tests/integration/test_v2_security_core.py` + - `server/tests/integration/test_v2_webrtc_signaling.py` + - Selected legacy voice tests passed (`test_offer_accept_audio_end`, `test_busy_and_invalid_audio_errors`) diff --git a/client-cpp-gui/include/blackwire/api/qt_api_client.hpp b/client-cpp-gui/include/blackwire/api/qt_api_client.hpp index be063df..8177bd8 100644 --- a/client-cpp-gui/include/blackwire/api/qt_api_client.hpp +++ b/client-cpp-gui/include/blackwire/api/qt_api_client.hpp @@ -31,26 +31,100 @@ class QtApiClient final : public QObject, public IApiClient { UserOut Me(const std::string& base_url, const std::string& access_token) override; - DeviceOut RegisterDevice( + AuthResponse RegisterDevice( const std::string& base_url, - const std::string& access_token, + const std::string& bootstrap_token, const DeviceRegisterRequest& request) override; + AuthResponse BindDevice( + const std::string& base_url, + const std::string& bootstrap_token, + const std::string& device_uid, + const std::string& nonce, + long long timestamp_ms, + const std::string& proof_signature_b64) override; + + std::vector ListDevices( + const std::string& base_url, + const std::string& access_token) override; + + DeviceOut RevokeDevice( + const std::string& base_url, + const std::string& access_token, + const std::string& device_uid) override; + UserDeviceLookup GetUserDevice( const std::string& base_url, const std::string& access_token, const std::string& peer_address) override; + ResolvePrekeysResponse ResolvePrekeys( + const std::string& base_url, + const std::string& access_token, + const std::string& peer_address) override; + + PrekeyUploadResponse UploadPrekeys( + const std::string& base_url, + const std::string& access_token, + const PrekeyUploadRequest& request) override; + + PresenceSetResponse SetPresenceStatus( + const std::string& base_url, + const std::string& access_token, + const PresenceSetRequest& request) override; + + PresenceResolveResponse ResolvePresence( + const std::string& base_url, + const std::string& access_token, + const PresenceResolveRequest& request) override; + ConversationOut CreateDm( const std::string& base_url, const std::string& access_token, const std::string& peer_address, const std::string& peer_username) override; + ConversationOut CreateGroup( + const std::string& base_url, + const std::string& access_token, + const CreateGroupConversationRequest& request) override; + std::vector ListConversations( const std::string& base_url, const std::string& access_token) override; + std::vector ListConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) override; + + std::vector InviteConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupInviteRequest& request) override; + + ConversationOut RenameConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupRenameRequest& request) override; + + ConversationMemberOut AcceptConversationInvite( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) override; + + ConversationMemberOut LeaveConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) override; + + ConversationRecipientsOut GetConversationRecipients( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) override; + std::vector ListMessages( const std::string& base_url, const std::string& access_token, diff --git a/client-cpp-gui/include/blackwire/audio/qt_audio_call_engine.hpp b/client-cpp-gui/include/blackwire/audio/qt_audio_call_engine.hpp index bf4c963..06f0a6d 100644 --- a/client-cpp-gui/include/blackwire/audio/qt_audio_call_engine.hpp +++ b/client-cpp-gui/include/blackwire/audio/qt_audio_call_engine.hpp @@ -51,7 +51,10 @@ class QtAudioCallEngine final : public QObject, public IAudioCallEngine { static constexpr int kSampleRate = 16000; static constexpr int kFrameBytes = 640; - static constexpr int kPlaybackQueueLimit = 200; + static constexpr int kSamplesPerFrame = kFrameBytes / 2; + static constexpr int kPlaybackQueueLimit = 64; + static constexpr int kPlaybackLatencyCapFrames = 24; + static constexpr int kMaxFramesMixedPerTick = 8; QAudioFormat format_; std::unique_ptr source_; diff --git a/client-cpp-gui/include/blackwire/controller/application_controller.hpp b/client-cpp-gui/include/blackwire/controller/application_controller.hpp index 0d9aff1..0afce19 100644 --- a/client-cpp-gui/include/blackwire/controller/application_controller.hpp +++ b/client-cpp-gui/include/blackwire/controller/application_controller.hpp @@ -10,6 +10,8 @@ #include #include +class QTimer; + #include "blackwire/interfaces/api_client.hpp" #include "blackwire/interfaces/audio_call_engine.hpp" #include "blackwire/interfaces/crypto_service.hpp" @@ -44,16 +46,27 @@ class ApplicationController final : public QObject { void Login(const QString& username, const QString& password); void Logout(); void SetupDevice(const QString& label); + void LoadAccountDevices(); + void RevokeDevice(const QString& device_uid); void LoadConversations(); void OpenConversationByPeer(const QString& username); void SelectConversation(const QString& conversation_id); + bool DismissDirectConversation(const QString& conversation_id); + bool LeaveGroupConversation(const QString& conversation_id); + bool CreateGroupFromCurrentDm(); + GroupInvitePickerView LoadInvitableContactsForCurrentGroup(const QString& query); + bool InviteContactsToCurrentGroup(const std::vector& peer_addresses); + bool RenameSelectedGroup(const QString& new_name); + bool IsSelectedConversationOwnerManagedGroup() const; void SendMessageToPeer(const QString& peer_username, const QString& message_text); + void SendFileToPeer(const QString& peer_username, const QString& file_path); void StartVoiceCall(); void AcceptVoiceCall(); void RejectVoiceCall(); void EndVoiceCall(); void SetCallMuted(bool muted); + void SetPresenceStatus(const QString& status); void LoadAudioDevices(); void SetPreferredAudioDevices(const QString& input_device_id, const QString& output_device_id); bool AcceptMessagesFromStrangers() const; @@ -81,26 +94,45 @@ class ApplicationController final : public QObject { const QString& preview_text); void MessageSendSucceeded(const QString& conversation_id, const QString& message_id); void ConnectionStatusChanged(const QString& status); + void UserPresenceChanged(const QString& status); void CallStateChanged(const CallStateView& state); void IncomingCallReceived(const CallStateView& state); void AudioDevicesChanged( const std::vector& input_devices, const std::vector& output_devices); + void AccountDevicesChanged(const std::vector& devices); void AudioDevicePreferenceChanged(const QString& input_device_id, const QString& output_device_id); void CallErrorOccurred(const QString& message); + void IntegrityWarningOccurred(const QString& message); void ErrorOccurred(const QString& message); private: std::string SecretNamespacePrefix() const; std::string SecretNamespaceNeedle() const; std::string SecretKey(const std::string& name) const; + std::string RequireBootstrapToken(); std::string RequireAccessToken(); std::string RequireRefreshToken(); + void SaveBootstrapToken(const TokenBundle& tokens); void SaveTokenPair(const TokenBundle& tokens); void RefreshAccessToken(); void StartRealtime(); void StopRealtime(); + void EncryptPlaintextCacheInState(); + void DecryptPlaintextCacheInState(); void PersistState(); + void RefreshPresenceCache(); + QString NormalizePresenceStatus(const QString& status) const; + QString PresenceForPeerAddress(const QString& peer_address) const; + QString ResolvePeerAddressForConversation(const std::string& conversation_id) const; + void AppendCallHistoryEntry(const QString& reason); + void AppendGroupRenameHistoryEntry( + const QString& conversation_id, + const QString& actor_address, + const QString& group_name, + const QString& dedupe_suffix); + QString FormatCallDuration(qint64 duration_ms) const; + std::vector BuildDirectCallParticipants(const QString& peer_user_id) const; void RefreshConversationList(); QString RenderMessage(const MessageOut& message, const std::string& plaintext) const; std::vector RenderThread(const std::string& conversation_id) const; @@ -108,6 +140,9 @@ class ApplicationController final : public QObject { bool IsWebSocketAuthError(const std::string& error) const; void ReauthenticateWebSocket(); void RecordDiagnostic(const QString& line); + void UploadCurrentDevicePrekeys(); + bool PreferRatchetV2b1() const; + bool DeviceSupportsMessageMode(const DeviceOut& device, const std::string& mode) const; bool ConversationExists(const std::string& conversation_id) const; const ConversationOut* FindConversation(const std::string& conversation_id) const; void RevealConversation(const std::string& conversation_id, bool select_conversation); @@ -129,7 +164,8 @@ class ApplicationController final : public QObject { void ReportCallError(const QString& message); std::optional FindConversationIdForPeer(const QString& normalized_peer) const; std::string ResolveConversationIdForPeer(const QString& normalized_peer); - DeviceOut ResolveRecipientDevice(const QString& normalized_peer); + std::vector ResolveRecipientDevices(const QString& normalized_peer); + std::vector ResolveOwnActiveDevices(); void UpsertConversationMeta( const std::string& conversation_id, const QString& peer_username, @@ -151,10 +187,24 @@ class ApplicationController final : public QObject { bool ws_reauth_in_progress_ = false; QString connection_status_ = "Disconnected"; std::deque diagnostics_; - std::unordered_map peer_device_cache_; + std::unordered_map> peer_device_cache_; + struct AttachmentPolicyCacheEntry { + qint64 attachment_inline_max_bytes = 0; + qint64 max_ciphertext_bytes = 0; + std::string source = "local"; + }; + std::unordered_map peer_attachment_policy_cache_; + std::optional local_attachment_policy_cache_; + std::unordered_map peer_presence_status_by_address_; + QString user_presence_status_ = "active"; std::map> pending_request_messages_; std::map pending_request_senders_; CallStateView call_state_; + bool call_initiated_locally_ = false; + qint64 call_started_at_ms_ = 0; + qint64 call_active_started_at_ms_ = 0; + bool pending_outgoing_end_request_ = false; + QTimer* presence_poll_timer_ = nullptr; int audio_sequence_ = 0; }; @@ -162,6 +212,8 @@ class ApplicationController final : public QObject { Q_DECLARE_METATYPE(blackwire::ConversationListItemView) Q_DECLARE_METATYPE(std::vector) +Q_DECLARE_METATYPE(blackwire::DeviceOut) +Q_DECLARE_METATYPE(std::vector) Q_DECLARE_METATYPE(blackwire::AudioDeviceOptionView) Q_DECLARE_METATYPE(std::vector) Q_DECLARE_METATYPE(blackwire::CallStateView) diff --git a/client-cpp-gui/include/blackwire/crypto/sodium_crypto_service.hpp b/client-cpp-gui/include/blackwire/crypto/sodium_crypto_service.hpp index 465d407..5bd5cc8 100644 --- a/client-cpp-gui/include/blackwire/crypto/sodium_crypto_service.hpp +++ b/client-cpp-gui/include/blackwire/crypto/sodium_crypto_service.hpp @@ -17,6 +17,14 @@ class SodiumCryptoService final : public ICryptoService { std::string DecryptWithPrivate( const std::string& private_key_b64, const std::string& ciphertext_b64) override; + std::string SignDetached( + const std::string& ed25519_private_key_b64, + const std::string& message) override; + bool VerifyDetached( + const std::string& ed25519_public_key_b64, + const std::string& message, + const std::string& signature_b64) override; + std::string Sha256(const std::string& data) override; private: static std::string EncodeBase64(const unsigned char* bytes, std::size_t length); diff --git a/client-cpp-gui/include/blackwire/interfaces/api_client.hpp b/client-cpp-gui/include/blackwire/interfaces/api_client.hpp index 649f4b5..5e5a5ad 100644 --- a/client-cpp-gui/include/blackwire/interfaces/api_client.hpp +++ b/client-cpp-gui/include/blackwire/interfaces/api_client.hpp @@ -41,26 +41,100 @@ class IApiClient { virtual UserOut Me(const std::string& base_url, const std::string& access_token) = 0; - virtual DeviceOut RegisterDevice( + virtual AuthResponse RegisterDevice( const std::string& base_url, - const std::string& access_token, + const std::string& bootstrap_token, const DeviceRegisterRequest& request) = 0; + virtual AuthResponse BindDevice( + const std::string& base_url, + const std::string& bootstrap_token, + const std::string& device_uid, + const std::string& nonce, + long long timestamp_ms, + const std::string& proof_signature_b64) = 0; + + virtual std::vector ListDevices( + const std::string& base_url, + const std::string& access_token) = 0; + + virtual DeviceOut RevokeDevice( + const std::string& base_url, + const std::string& access_token, + const std::string& device_uid) = 0; + virtual UserDeviceLookup GetUserDevice( const std::string& base_url, const std::string& access_token, const std::string& peer_address) = 0; + virtual ResolvePrekeysResponse ResolvePrekeys( + const std::string& base_url, + const std::string& access_token, + const std::string& peer_address) = 0; + + virtual PrekeyUploadResponse UploadPrekeys( + const std::string& base_url, + const std::string& access_token, + const PrekeyUploadRequest& request) = 0; + + virtual PresenceSetResponse SetPresenceStatus( + const std::string& base_url, + const std::string& access_token, + const PresenceSetRequest& request) = 0; + + virtual PresenceResolveResponse ResolvePresence( + const std::string& base_url, + const std::string& access_token, + const PresenceResolveRequest& request) = 0; + virtual ConversationOut CreateDm( const std::string& base_url, const std::string& access_token, const std::string& peer_address, const std::string& peer_username) = 0; + virtual ConversationOut CreateGroup( + const std::string& base_url, + const std::string& access_token, + const CreateGroupConversationRequest& request) = 0; + virtual std::vector ListConversations( const std::string& base_url, const std::string& access_token) = 0; + virtual std::vector ListConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) = 0; + + virtual std::vector InviteConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupInviteRequest& request) = 0; + + virtual ConversationOut RenameConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupRenameRequest& request) = 0; + + virtual ConversationMemberOut AcceptConversationInvite( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) = 0; + + virtual ConversationMemberOut LeaveConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) = 0; + + virtual ConversationRecipientsOut GetConversationRecipients( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) = 0; + virtual std::vector ListMessages( const std::string& base_url, const std::string& access_token, diff --git a/client-cpp-gui/include/blackwire/interfaces/crypto_service.hpp b/client-cpp-gui/include/blackwire/interfaces/crypto_service.hpp index 6c98c1a..666bebf 100644 --- a/client-cpp-gui/include/blackwire/interfaces/crypto_service.hpp +++ b/client-cpp-gui/include/blackwire/interfaces/crypto_service.hpp @@ -22,6 +22,14 @@ class ICryptoService { virtual std::string DecryptWithPrivate( const std::string& private_key_b64, const std::string& ciphertext_b64) = 0; + virtual std::string SignDetached( + const std::string& ed25519_private_key_b64, + const std::string& message) = 0; + virtual bool VerifyDetached( + const std::string& ed25519_public_key_b64, + const std::string& message, + const std::string& signature_b64) = 0; + virtual std::string Sha256(const std::string& data) = 0; }; } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/interfaces/ws_client.hpp b/client-cpp-gui/include/blackwire/interfaces/ws_client.hpp index 4a2b392..16da431 100644 --- a/client-cpp-gui/include/blackwire/interfaces/ws_client.hpp +++ b/client-cpp-gui/include/blackwire/interfaces/ws_client.hpp @@ -11,6 +11,7 @@ class IWsClient { public: using MessageHandler = std::function; using CallIncomingHandler = std::function; + using CallGroupStateHandler = std::function; using CallRingingHandler = std::function; using CallAcceptedHandler = std::function; using CallRejectedHandler = std::function; @@ -18,6 +19,10 @@ class IWsClient { using CallEndedHandler = std::function; using CallAudioHandler = std::function; using CallErrorHandler = std::function; + using CallWebRtcOfferHandler = std::function; + using CallWebRtcAnswerHandler = std::function; + using CallWebRtcIceHandler = std::function; + using GroupRenamedHandler = std::function; using ErrorHandler = std::function; using StatusHandler = std::function; @@ -26,6 +31,7 @@ class IWsClient { virtual void SetHandlers( MessageHandler on_message, CallIncomingHandler on_call_incoming, + CallGroupStateHandler on_call_group_state, CallRingingHandler on_call_ringing, CallAcceptedHandler on_call_accepted, CallRejectedHandler on_call_rejected, @@ -33,6 +39,10 @@ class IWsClient { CallEndedHandler on_call_ended, CallAudioHandler on_call_audio, CallErrorHandler on_call_error, + CallWebRtcOfferHandler on_call_webrtc_offer, + CallWebRtcAnswerHandler on_call_webrtc_answer, + CallWebRtcIceHandler on_call_webrtc_ice, + GroupRenamedHandler on_group_renamed, ErrorHandler on_error, StatusHandler on_status) = 0; virtual void Connect(const std::string& base_url, const std::string& access_token) = 0; @@ -43,6 +53,9 @@ class IWsClient { virtual void SendCallReject(const VoiceCallReject& reject) = 0; virtual void SendCallEnd(const VoiceCallEnd& end) = 0; virtual void SendCallAudioChunk(const VoiceAudioChunk& chunk) = 0; + virtual void SendCallWebRtcOffer(const VoiceCallWebRtcOffer& offer) = 0; + virtual void SendCallWebRtcAnswer(const VoiceCallWebRtcAnswer& answer) = 0; + virtual void SendCallWebRtcIce(const VoiceCallWebRtcIce& ice) = 0; }; } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/models/dto.hpp b/client-cpp-gui/include/blackwire/models/dto.hpp index 2e0b629..39297c6 100644 --- a/client-cpp-gui/include/blackwire/models/dto.hpp +++ b/client-cpp-gui/include/blackwire/models/dto.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -21,6 +22,9 @@ struct TokenBundle { int access_expires_in = 0; std::string refresh_token; int refresh_expires_in = 0; + std::string bootstrap_token; + int bootstrap_expires_in = 0; + std::string device_uid; }; struct AuthResponse { @@ -32,14 +36,21 @@ struct DeviceRegisterRequest { std::string label; std::string ik_ed25519_pub; std::string enc_x25519_pub; + std::string pub_sign_key; + std::string pub_dh_key; }; struct DeviceOut { std::string id; + std::string device_uid; std::string user_id; std::string label; std::string ik_ed25519_pub; std::string enc_x25519_pub; + std::string status; + std::vector supported_message_modes; + std::string last_seen_at; + std::string revoked_at; std::string created_at; }; @@ -47,6 +58,10 @@ struct UserDeviceLookup { std::string username; std::string peer_address; DeviceOut device; + std::vector devices; + long long attachment_inline_max_bytes = 0; + long long max_ciphertext_bytes = 0; + std::string attachment_policy_source = "local"; }; struct ConversationOut { @@ -59,15 +74,56 @@ struct ConversationOut { std::string peer_username; std::string peer_server_onion; std::string peer_address; + std::string conversation_type = "direct"; + std::string group_uid; + std::string group_name; + int member_count = 0; + std::string membership_state = "none"; + bool can_manage_members = false; + std::string origin_server_onion; + std::string owner_address; +}; + +struct CreateGroupConversationRequest { + std::string name; + std::vector member_addresses; +}; + +struct GroupInviteRequest { + std::vector member_addresses; +}; + +struct GroupRenameRequest { + std::string name; +}; + +struct ConversationMemberOut { + std::string id; + std::string member_user_id; + std::string member_address; + std::string member_server_onion; + std::string role = "member"; + std::string status = "invited"; + std::string invited_by_address; + std::string invited_at; + std::string joined_at; + std::string left_at; + std::string updated_at; }; struct CipherEnvelope { int version = 1; std::string alg = "libsodium-sealedbox-v1"; std::string recipient_device_id; + std::string recipient_user_address; + std::string recipient_device_uid; std::string ciphertext_b64; std::string aad_b64; std::string client_message_id; + std::string signature_b64; + std::string sender_device_pubkey; + nlohmann::json ratchet_header = nullptr; + nlohmann::json ratchet_init = nullptr; }; struct MessageOut { @@ -76,14 +132,111 @@ struct MessageOut { std::string sender_user_id; std::string sender_address; std::string sender_device_id; + std::string sender_device_uid; + std::string sender_device_pubkey; + std::string encryption_mode = "sealedbox_v0_2a"; std::string client_message_id; + long long sent_at_ms = 0; + std::string sender_prev_hash; + std::string sender_chain_hash; CipherEnvelope envelope; std::string created_at; }; struct MessageSendRequest { std::string conversation_id; + std::string encryption_mode = "sealedbox_v0_2a"; CipherEnvelope envelope; + std::string client_message_id; + long long sent_at_ms = 0; + std::string sender_prev_hash; + std::string sender_chain_hash; + std::vector envelopes; +}; + +struct SignedPrekeyUpload { + int key_id = 1; + std::string pub_x25519_b64; + std::string sig_by_device_sign_key_b64; + std::string expires_at; +}; + +struct OneTimePrekeyUpload { + int key_id = 1; + std::string pub_x25519_b64; +}; + +struct PrekeyUploadRequest { + SignedPrekeyUpload signed_prekey; + std::vector one_time_prekeys; +}; + +struct PrekeyUploadResponse { + int uploaded_signed_prekey_key_id = 0; + int accepted_one_time_prekeys = 0; +}; + +struct SignedPrekeyOut { + int key_id = 0; + std::string pub_x25519_b64; + std::string sig_by_device_sign_key_b64; + std::string expires_at; +}; + +struct OneTimePrekeyOut { + int key_id = 0; + std::string pub_x25519_b64; +}; + +struct ResolvedPrekeyDevice { + std::string device_uid; + std::string pub_sign_key; + std::string pub_dh_key; + std::vector supported_message_modes; + bool opk_missing = false; + std::optional signed_prekey; + std::optional one_time_prekey; +}; + +struct ResolvePrekeysResponse { + std::string username; + std::string peer_address; + std::vector devices; +}; + +struct ConversationRecipientDeviceOut { + std::string member_address; + std::string member_status = "active"; + DeviceOut device; + std::optional prekey; +}; + +struct ConversationRecipientsOut { + std::string conversation_id; + std::string conversation_type = "direct"; + std::vector recipients; +}; + +struct PresenceSetRequest { + std::string status; +}; + +struct PresenceSetResponse { + std::string status; +}; + +struct PresenceResolveRequest { + std::vector peer_addresses; +}; + +struct PresencePeerOut { + std::string peer_address; + std::string status; +}; + +struct PresenceResolveResponse { + std::string self_status; + std::vector peers; }; struct MessageSendResponse { @@ -92,6 +245,7 @@ struct MessageSendResponse { }; struct WsEventMessageNew { + std::string copy_id; MessageOut message; }; @@ -119,6 +273,38 @@ struct VoiceAudioChunk { std::string pcm_b64; }; +struct VoiceCallWebRtcOffer { + std::string call_id; + std::string sdp; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string target_user_address; + std::string source_user_address; +}; + +struct VoiceCallWebRtcAnswer { + std::string call_id; + std::string sdp; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string target_user_address; + std::string source_user_address; +}; + +struct VoiceCallWebRtcIce { + std::string call_id; + std::string candidate; + std::string sdp_mid; + int sdp_mline_index = -1; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string target_user_address; + std::string source_user_address; +}; + struct WsEventCallIncoming { std::string call_id; std::string conversation_id; @@ -138,6 +324,10 @@ struct WsEventCallAccepted { std::string conversation_id; std::string peer_user_id; std::string peer_user_address; + int call_schema_version = 1; + std::string call_mode; + int max_participants = 2; + nlohmann::json ice_servers = nlohmann::json::array(); }; struct WsEventCallRejected { @@ -157,6 +347,21 @@ struct WsEventCallEnded { std::string by_user_id; }; +struct WsEventCallGroupParticipant { + std::string member_address; + std::string state; +}; + +struct WsEventCallGroupState { + std::string call_id; + std::string conversation_id; + std::string group_uid; + std::string state; + std::string call_mode; + int max_participants = 2; + std::vector participants; +}; + struct WsEventCallAudio { std::string call_id; std::string from_user_id; @@ -170,6 +375,64 @@ struct WsEventCallError { std::string detail; }; +struct WsEventCallWebRtcOffer { + std::string call_id; + std::string sdp; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string from_user_address; + std::string source_user_address; + std::string target_user_address; +}; + +struct WsEventCallWebRtcAnswer { + std::string call_id; + std::string sdp; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string from_user_address; + std::string source_user_address; + std::string target_user_address; +}; + +struct WsEventCallWebRtcIce { + std::string call_id; + std::string candidate; + std::string sdp_mid; + int sdp_mline_index = -1; + int call_schema_version = 1; + std::string call_mode = "webrtc"; + int max_participants = 2; + std::string from_user_address; + std::string source_user_address; + std::string target_user_address; +}; + +struct WsEventGroupRenamed { + std::string conversation_id; + std::string group_uid; + std::string group_name; + std::string actor_address; + int event_seq = 0; +}; + +inline std::string JsonStringOrDefault( + const nlohmann::json& j, + const char* key, + const std::string& fallback = "") { + const auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return fallback; + } + try { + return it->get(); + } catch (...) { + return fallback; + } +} + inline void to_json(nlohmann::json& j, const UserOut& v) { j = nlohmann::json{ {"id", v.id}, @@ -195,15 +458,21 @@ inline void to_json(nlohmann::json& j, const TokenBundle& v) { {"access_expires_in", v.access_expires_in}, {"refresh_token", v.refresh_token}, {"refresh_expires_in", v.refresh_expires_in}, + {"bootstrap_token", v.bootstrap_token}, + {"bootstrap_expires_in", v.bootstrap_expires_in}, + {"device_uid", v.device_uid}, }; } inline void from_json(const nlohmann::json& j, TokenBundle& v) { - j.at("access_token").get_to(v.access_token); + v.access_token = j.value("access_token", ""); v.token_type = j.value("token_type", "bearer"); - j.at("access_expires_in").get_to(v.access_expires_in); - j.at("refresh_token").get_to(v.refresh_token); - j.at("refresh_expires_in").get_to(v.refresh_expires_in); + v.access_expires_in = j.value("access_expires_in", 0); + v.refresh_token = j.value("refresh_token", ""); + v.refresh_expires_in = j.value("refresh_expires_in", 0); + v.bootstrap_token = j.value("bootstrap_token", ""); + v.bootstrap_expires_in = j.value("bootstrap_expires_in", 0); + v.device_uid = j.value("device_uid", ""); } inline void to_json(nlohmann::json& j, const AuthResponse& v) { @@ -220,43 +489,86 @@ inline void to_json(nlohmann::json& j, const DeviceRegisterRequest& v) { {"label", v.label}, {"ik_ed25519_pub", v.ik_ed25519_pub}, {"enc_x25519_pub", v.enc_x25519_pub}, + {"pub_sign_key", v.pub_sign_key.empty() ? v.ik_ed25519_pub : v.pub_sign_key}, + {"pub_dh_key", v.pub_dh_key.empty() ? v.enc_x25519_pub : v.pub_dh_key}, }; } inline void from_json(const nlohmann::json& j, DeviceRegisterRequest& v) { j.at("label").get_to(v.label); - j.at("ik_ed25519_pub").get_to(v.ik_ed25519_pub); - j.at("enc_x25519_pub").get_to(v.enc_x25519_pub); + v.ik_ed25519_pub = j.value("ik_ed25519_pub", j.value("pub_sign_key", "")); + v.enc_x25519_pub = j.value("enc_x25519_pub", j.value("pub_dh_key", "")); + v.pub_sign_key = j.value("pub_sign_key", v.ik_ed25519_pub); + v.pub_dh_key = j.value("pub_dh_key", v.enc_x25519_pub); } inline void to_json(nlohmann::json& j, const DeviceOut& v) { j = nlohmann::json{ - {"id", v.id}, + {"id", v.id.empty() ? v.device_uid : v.id}, + {"device_uid", v.device_uid.empty() ? v.id : v.device_uid}, {"user_id", v.user_id}, {"label", v.label}, {"ik_ed25519_pub", v.ik_ed25519_pub}, {"enc_x25519_pub", v.enc_x25519_pub}, + {"pub_sign_key", v.ik_ed25519_pub}, + {"pub_dh_key", v.enc_x25519_pub}, + {"status", v.status}, + {"supported_message_modes", v.supported_message_modes}, + {"last_seen_at", v.last_seen_at}, + {"revoked_at", v.revoked_at.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.revoked_at)}, {"created_at", v.created_at}, }; } inline void from_json(const nlohmann::json& j, DeviceOut& v) { - j.at("id").get_to(v.id); + v.id = j.value("id", j.value("device_uid", "")); + v.device_uid = j.value("device_uid", v.id); j.at("user_id").get_to(v.user_id); j.at("label").get_to(v.label); - j.at("ik_ed25519_pub").get_to(v.ik_ed25519_pub); - j.at("enc_x25519_pub").get_to(v.enc_x25519_pub); + v.ik_ed25519_pub = j.value("ik_ed25519_pub", j.value("pub_sign_key", "")); + v.enc_x25519_pub = j.value("enc_x25519_pub", j.value("pub_dh_key", "")); + v.status = j.value("status", "active"); + v.supported_message_modes = j.value("supported_message_modes", std::vector{"sealedbox_v0_2a"}); + v.last_seen_at = j.value("last_seen_at", ""); + if (j.contains("revoked_at") && !j.at("revoked_at").is_null()) { + v.revoked_at = j.value("revoked_at", ""); + } else { + v.revoked_at.clear(); + } j.at("created_at").get_to(v.created_at); } inline void to_json(nlohmann::json& j, const UserDeviceLookup& v) { - j = nlohmann::json{{"username", v.username}, {"peer_address", v.peer_address}, {"device", v.device}}; + j = nlohmann::json{ + {"username", v.username}, + {"peer_address", v.peer_address}, + {"device", v.device}, + {"devices", v.devices.empty() ? nlohmann::json::array({v.device}) : nlohmann::json(v.devices)}, + {"attachment_inline_max_bytes", v.attachment_inline_max_bytes}, + {"max_ciphertext_bytes", v.max_ciphertext_bytes}, + {"attachment_policy_source", v.attachment_policy_source}, + }; } inline void from_json(const nlohmann::json& j, UserDeviceLookup& v) { j.at("username").get_to(v.username); v.peer_address = j.value("peer_address", ""); - j.at("device").get_to(v.device); + if (j.contains("device")) { + j.at("device").get_to(v.device); + } + if (j.contains("devices")) { + j.at("devices").get_to(v.devices); + } else if (!v.device.id.empty() || !v.device.device_uid.empty()) { + v.devices = {v.device}; + } else { + v.devices.clear(); + } + if (v.device.id.empty() && !v.devices.empty()) { + v.device = v.devices.front(); + } + v.attachment_inline_max_bytes = j.value("attachment_inline_max_bytes", 0LL); + v.max_ciphertext_bytes = j.value("max_ciphertext_bytes", 0LL); + v.attachment_policy_source = j.value("attachment_policy_source", "local"); } inline void to_json(nlohmann::json& j, const ConversationOut& v) { @@ -270,6 +582,14 @@ inline void to_json(nlohmann::json& j, const ConversationOut& v) { {"peer_username", v.peer_username}, {"peer_server_onion", v.peer_server_onion}, {"peer_address", v.peer_address}, + {"conversation_type", v.conversation_type}, + {"group_uid", v.group_uid.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.group_uid)}, + {"group_name", v.group_name}, + {"member_count", v.member_count}, + {"membership_state", v.membership_state}, + {"can_manage_members", v.can_manage_members}, + {"origin_server_onion", v.origin_server_onion}, + {"owner_address", v.owner_address}, }; } @@ -283,9 +603,195 @@ inline void from_json(const nlohmann::json& j, ConversationOut& v) { v.peer_username = j.value("peer_username", ""); v.peer_server_onion = j.value("peer_server_onion", ""); v.peer_address = j.value("peer_address", ""); + v.conversation_type = j.value("conversation_type", "direct"); + if (j.contains("group_uid") && !j.at("group_uid").is_null()) { + v.group_uid = j.value("group_uid", ""); + } else { + v.group_uid.clear(); + } + v.group_name = j.value("group_name", ""); + v.member_count = j.value("member_count", 0); + v.membership_state = j.value("membership_state", "none"); + v.can_manage_members = j.value("can_manage_members", false); + v.origin_server_onion = j.value("origin_server_onion", ""); + v.owner_address = j.value("owner_address", ""); +} + +inline void to_json(nlohmann::json& j, const CreateGroupConversationRequest& v) { + j = nlohmann::json{ + {"name", v.name}, + {"member_addresses", v.member_addresses}, + }; +} + +inline void from_json(const nlohmann::json& j, CreateGroupConversationRequest& v) { + v.name = j.value("name", ""); + v.member_addresses = j.value("member_addresses", std::vector{}); +} + +inline void to_json(nlohmann::json& j, const GroupInviteRequest& v) { + j = nlohmann::json{ + {"member_addresses", v.member_addresses}, + }; +} + +inline void from_json(const nlohmann::json& j, GroupInviteRequest& v) { + v.member_addresses = j.value("member_addresses", std::vector{}); +} + +inline void to_json(nlohmann::json& j, const GroupRenameRequest& v) { + j = nlohmann::json{ + {"name", v.name}, + }; +} + +inline void from_json(const nlohmann::json& j, GroupRenameRequest& v) { + v.name = j.value("name", ""); +} + +inline void to_json(nlohmann::json& j, const ConversationMemberOut& v) { + j = nlohmann::json{ + {"id", v.id}, + {"member_user_id", v.member_user_id.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.member_user_id)}, + {"member_address", v.member_address}, + {"member_server_onion", v.member_server_onion}, + {"role", v.role}, + {"status", v.status}, + {"invited_by_address", v.invited_by_address}, + {"invited_at", v.invited_at}, + {"joined_at", v.joined_at.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.joined_at)}, + {"left_at", v.left_at.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.left_at)}, + {"updated_at", v.updated_at}, + }; +} + +inline void from_json(const nlohmann::json& j, ConversationMemberOut& v) { + v.id = JsonStringOrDefault(j, "id"); + v.member_user_id = JsonStringOrDefault(j, "member_user_id"); + v.member_address = JsonStringOrDefault(j, "member_address"); + v.member_server_onion = JsonStringOrDefault(j, "member_server_onion"); + v.role = JsonStringOrDefault(j, "role", "member"); + v.status = JsonStringOrDefault(j, "status", "invited"); + v.invited_by_address = JsonStringOrDefault(j, "invited_by_address"); + v.invited_at = JsonStringOrDefault(j, "invited_at"); + v.joined_at = JsonStringOrDefault(j, "joined_at"); + v.left_at = JsonStringOrDefault(j, "left_at"); + v.updated_at = JsonStringOrDefault(j, "updated_at"); +} + +inline void to_json(nlohmann::json& j, const ConversationRecipientDeviceOut& v) { + nlohmann::json prekey_json = nlohmann::json(nullptr); + if (v.prekey.has_value()) { + const auto& prekey = v.prekey.value(); + prekey_json = nlohmann::json{ + {"device_uid", prekey.device_uid}, + {"pub_sign_key", prekey.pub_sign_key}, + {"pub_dh_key", prekey.pub_dh_key}, + {"supported_message_modes", prekey.supported_message_modes}, + {"opk_missing", prekey.opk_missing}, + }; + if (prekey.signed_prekey.has_value()) { + prekey_json["signed_prekey"] = nlohmann::json{ + {"key_id", prekey.signed_prekey->key_id}, + {"pub_x25519_b64", prekey.signed_prekey->pub_x25519_b64}, + {"sig_by_device_sign_key_b64", prekey.signed_prekey->sig_by_device_sign_key_b64}, + {"expires_at", prekey.signed_prekey->expires_at}, + }; + } else { + prekey_json["signed_prekey"] = nullptr; + } + if (prekey.one_time_prekey.has_value()) { + prekey_json["one_time_prekey"] = nlohmann::json{ + {"key_id", prekey.one_time_prekey->key_id}, + {"pub_x25519_b64", prekey.one_time_prekey->pub_x25519_b64}, + }; + } else { + prekey_json["one_time_prekey"] = nullptr; + } + } + j = nlohmann::json{ + {"member_address", v.member_address}, + {"member_status", v.member_status}, + {"device", v.device}, + {"prekey", prekey_json}, + }; +} + +inline void from_json(const nlohmann::json& j, ConversationRecipientDeviceOut& v) { + v.member_address = j.value("member_address", ""); + v.member_status = j.value("member_status", "active"); + if (j.contains("device") && j.at("device").is_object()) { + j.at("device").get_to(v.device); + } else { + v.device = DeviceOut{}; + } + if (j.contains("prekey") && !j.at("prekey").is_null()) { + const auto& prekey_json = j.at("prekey"); + ResolvedPrekeyDevice prekey; + prekey.device_uid = prekey_json.value("device_uid", ""); + prekey.pub_sign_key = prekey_json.value("pub_sign_key", ""); + prekey.pub_dh_key = prekey_json.value("pub_dh_key", ""); + prekey.supported_message_modes = + prekey_json.value("supported_message_modes", std::vector{"sealedbox_v0_2a"}); + prekey.opk_missing = prekey_json.value("opk_missing", false); + if (prekey_json.contains("signed_prekey") && !prekey_json.at("signed_prekey").is_null()) { + SignedPrekeyOut signed_prekey; + const auto& signed_prekey_json = prekey_json.at("signed_prekey"); + signed_prekey.key_id = signed_prekey_json.value("key_id", 0); + signed_prekey.pub_x25519_b64 = signed_prekey_json.value("pub_x25519_b64", ""); + signed_prekey.sig_by_device_sign_key_b64 = signed_prekey_json.value("sig_by_device_sign_key_b64", ""); + signed_prekey.expires_at = signed_prekey_json.value("expires_at", ""); + prekey.signed_prekey = signed_prekey; + } else { + prekey.signed_prekey.reset(); + } + if (prekey_json.contains("one_time_prekey") && !prekey_json.at("one_time_prekey").is_null()) { + OneTimePrekeyOut one_time_prekey; + const auto& one_time_prekey_json = prekey_json.at("one_time_prekey"); + one_time_prekey.key_id = one_time_prekey_json.value("key_id", 0); + one_time_prekey.pub_x25519_b64 = one_time_prekey_json.value("pub_x25519_b64", ""); + prekey.one_time_prekey = one_time_prekey; + } else { + prekey.one_time_prekey.reset(); + } + v.prekey = prekey; + } else { + v.prekey.reset(); + } +} + +inline void to_json(nlohmann::json& j, const ConversationRecipientsOut& v) { + j = nlohmann::json{ + {"conversation_id", v.conversation_id}, + {"conversation_type", v.conversation_type}, + {"recipients", v.recipients}, + }; +} + +inline void from_json(const nlohmann::json& j, ConversationRecipientsOut& v) { + v.conversation_id = j.value("conversation_id", ""); + v.conversation_type = j.value("conversation_type", "direct"); + v.recipients = j.value("recipients", std::vector{}); } inline void to_json(nlohmann::json& j, const CipherEnvelope& v) { + j = nlohmann::json::object(); + if (!v.recipient_user_address.empty() || !v.recipient_device_uid.empty() || !v.signature_b64.empty()) { + j["recipient_user_address"] = v.recipient_user_address; + j["recipient_device_uid"] = v.recipient_device_uid.empty() ? v.recipient_device_id : v.recipient_device_uid; + j["ciphertext_b64"] = v.ciphertext_b64; + j["aad_b64"] = v.aad_b64.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.aad_b64); + j["signature_b64"] = v.signature_b64; + j["sender_device_pubkey"] = v.sender_device_pubkey; + if (!v.ratchet_header.is_null()) { + j["ratchet_header"] = v.ratchet_header; + } + if (!v.ratchet_init.is_null()) { + j["ratchet_init"] = v.ratchet_init; + } + return; + } + j = nlohmann::json{ {"version", v.version}, {"alg", v.alg}, @@ -299,14 +805,20 @@ inline void to_json(nlohmann::json& j, const CipherEnvelope& v) { inline void from_json(const nlohmann::json& j, CipherEnvelope& v) { v.version = j.value("version", 1); v.alg = j.value("alg", "libsodium-sealedbox-v1"); - j.at("recipient_device_id").get_to(v.recipient_device_id); + v.recipient_device_id = j.value("recipient_device_id", j.value("recipient_device_uid", "")); + v.recipient_device_uid = j.value("recipient_device_uid", v.recipient_device_id); + v.recipient_user_address = j.value("recipient_user_address", ""); j.at("ciphertext_b64").get_to(v.ciphertext_b64); if (j.contains("aad_b64") && !j.at("aad_b64").is_null()) { j.at("aad_b64").get_to(v.aad_b64); } else { v.aad_b64.clear(); } - j.at("client_message_id").get_to(v.client_message_id); + v.client_message_id = j.value("client_message_id", ""); + v.signature_b64 = j.value("signature_b64", ""); + v.sender_device_pubkey = j.value("sender_device_pubkey", ""); + v.ratchet_header = j.value("ratchet_header", nlohmann::json(nullptr)); + v.ratchet_init = j.value("ratchet_init", nlohmann::json(nullptr)); } inline void to_json(nlohmann::json& j, const MessageOut& v) { @@ -315,8 +827,14 @@ inline void to_json(nlohmann::json& j, const MessageOut& v) { {"conversation_id", v.conversation_id}, {"sender_user_id", v.sender_user_id}, {"sender_address", v.sender_address}, - {"sender_device_id", v.sender_device_id}, + {"sender_device_id", v.sender_device_id.empty() ? v.sender_device_uid : v.sender_device_id}, + {"sender_device_uid", v.sender_device_uid.empty() ? v.sender_device_id : v.sender_device_uid}, + {"sender_device_pubkey", v.sender_device_pubkey}, + {"encryption_mode", v.encryption_mode}, {"client_message_id", v.client_message_id}, + {"sent_at_ms", v.sent_at_ms}, + {"sender_prev_hash", v.sender_prev_hash}, + {"sender_chain_hash", v.sender_chain_hash}, {"envelope", v.envelope}, {"created_at", v.created_at}, }; @@ -327,8 +845,14 @@ inline void from_json(const nlohmann::json& j, MessageOut& v) { j.at("conversation_id").get_to(v.conversation_id); v.sender_user_id = j.value("sender_user_id", ""); v.sender_address = j.value("sender_address", ""); - j.at("sender_device_id").get_to(v.sender_device_id); + v.sender_device_id = j.value("sender_device_id", j.value("sender_device_uid", "")); + v.sender_device_uid = j.value("sender_device_uid", v.sender_device_id); + v.sender_device_pubkey = j.value("sender_device_pubkey", ""); + v.encryption_mode = j.value("encryption_mode", "sealedbox_v0_2a"); j.at("client_message_id").get_to(v.client_message_id); + v.sent_at_ms = j.value("sent_at_ms", 0LL); + v.sender_prev_hash = j.value("sender_prev_hash", ""); + v.sender_chain_hash = j.value("sender_chain_hash", ""); if (j.contains("envelope")) { j.at("envelope").get_to(v.envelope); @@ -340,12 +864,33 @@ inline void from_json(const nlohmann::json& j, MessageOut& v) { } inline void to_json(nlohmann::json& j, const MessageSendRequest& v) { + if (!v.envelopes.empty()) { + j = nlohmann::json{ + {"conversation_id", v.conversation_id}, + {"encryption_mode", v.encryption_mode}, + {"client_message_id", v.client_message_id}, + {"sent_at_ms", v.sent_at_ms}, + {"sender_prev_hash", v.sender_prev_hash}, + {"sender_chain_hash", v.sender_chain_hash}, + {"envelopes", v.envelopes}, + }; + return; + } j = nlohmann::json{{"conversation_id", v.conversation_id}, {"envelope", v.envelope}}; } inline void from_json(const nlohmann::json& j, MessageSendRequest& v) { j.at("conversation_id").get_to(v.conversation_id); - j.at("envelope").get_to(v.envelope); + v.encryption_mode = j.value("encryption_mode", "sealedbox_v0_2a"); + if (j.contains("envelopes")) { + j.at("envelopes").get_to(v.envelopes); + } else if (j.contains("envelope")) { + j.at("envelope").get_to(v.envelope); + } + v.client_message_id = j.value("client_message_id", ""); + v.sent_at_ms = j.value("sent_at_ms", 0LL); + v.sender_prev_hash = j.value("sender_prev_hash", ""); + v.sender_chain_hash = j.value("sender_chain_hash", ""); } inline void to_json(nlohmann::json& j, const MessageSendResponse& v) { @@ -357,7 +902,202 @@ inline void from_json(const nlohmann::json& j, MessageSendResponse& v) { j.at("message").get_to(v.message); } +inline void to_json(nlohmann::json& j, const SignedPrekeyUpload& v) { + j = nlohmann::json{ + {"key_id", v.key_id}, + {"pub_x25519_b64", v.pub_x25519_b64}, + {"sig_by_device_sign_key_b64", v.sig_by_device_sign_key_b64}, + {"expires_at", v.expires_at}, + }; +} + +inline void from_json(const nlohmann::json& j, SignedPrekeyUpload& v) { + v.key_id = j.value("key_id", 1); + v.pub_x25519_b64 = j.value("pub_x25519_b64", ""); + v.sig_by_device_sign_key_b64 = j.value("sig_by_device_sign_key_b64", ""); + v.expires_at = j.value("expires_at", ""); +} + +inline void to_json(nlohmann::json& j, const OneTimePrekeyUpload& v) { + j = nlohmann::json{ + {"key_id", v.key_id}, + {"pub_x25519_b64", v.pub_x25519_b64}, + }; +} + +inline void from_json(const nlohmann::json& j, OneTimePrekeyUpload& v) { + v.key_id = j.value("key_id", 1); + v.pub_x25519_b64 = j.value("pub_x25519_b64", ""); +} + +inline void to_json(nlohmann::json& j, const PrekeyUploadRequest& v) { + j = nlohmann::json{ + {"signed_prekey", v.signed_prekey}, + {"one_time_prekeys", v.one_time_prekeys}, + }; +} + +inline void from_json(const nlohmann::json& j, PrekeyUploadRequest& v) { + if (j.contains("signed_prekey")) { + j.at("signed_prekey").get_to(v.signed_prekey); + } + if (j.contains("one_time_prekeys")) { + j.at("one_time_prekeys").get_to(v.one_time_prekeys); + } +} + +inline void to_json(nlohmann::json& j, const PrekeyUploadResponse& v) { + j = nlohmann::json{ + {"uploaded_signed_prekey_key_id", v.uploaded_signed_prekey_key_id}, + {"accepted_one_time_prekeys", v.accepted_one_time_prekeys}, + }; +} + +inline void from_json(const nlohmann::json& j, PrekeyUploadResponse& v) { + v.uploaded_signed_prekey_key_id = j.value("uploaded_signed_prekey_key_id", 0); + v.accepted_one_time_prekeys = j.value("accepted_one_time_prekeys", 0); +} + +inline void to_json(nlohmann::json& j, const SignedPrekeyOut& v) { + j = nlohmann::json{ + {"key_id", v.key_id}, + {"pub_x25519_b64", v.pub_x25519_b64}, + {"sig_by_device_sign_key_b64", v.sig_by_device_sign_key_b64}, + {"expires_at", v.expires_at}, + }; +} + +inline void from_json(const nlohmann::json& j, SignedPrekeyOut& v) { + v.key_id = j.value("key_id", 0); + v.pub_x25519_b64 = j.value("pub_x25519_b64", ""); + v.sig_by_device_sign_key_b64 = j.value("sig_by_device_sign_key_b64", ""); + v.expires_at = j.value("expires_at", ""); +} + +inline void to_json(nlohmann::json& j, const OneTimePrekeyOut& v) { + j = nlohmann::json{ + {"key_id", v.key_id}, + {"pub_x25519_b64", v.pub_x25519_b64}, + }; +} + +inline void from_json(const nlohmann::json& j, OneTimePrekeyOut& v) { + v.key_id = j.value("key_id", 0); + v.pub_x25519_b64 = j.value("pub_x25519_b64", ""); +} + +inline void to_json(nlohmann::json& j, const ResolvedPrekeyDevice& v) { + j = nlohmann::json{ + {"device_uid", v.device_uid}, + {"pub_sign_key", v.pub_sign_key}, + {"pub_dh_key", v.pub_dh_key}, + {"supported_message_modes", v.supported_message_modes}, + {"opk_missing", v.opk_missing}, + }; + if (v.signed_prekey.has_value()) { + j["signed_prekey"] = v.signed_prekey.value(); + } else { + j["signed_prekey"] = nullptr; + } + if (v.one_time_prekey.has_value()) { + j["one_time_prekey"] = v.one_time_prekey.value(); + } else { + j["one_time_prekey"] = nullptr; + } +} + +inline void from_json(const nlohmann::json& j, ResolvedPrekeyDevice& v) { + v.device_uid = j.value("device_uid", ""); + v.pub_sign_key = j.value("pub_sign_key", ""); + v.pub_dh_key = j.value("pub_dh_key", ""); + v.supported_message_modes = j.value("supported_message_modes", std::vector{}); + v.opk_missing = j.value("opk_missing", false); + if (j.contains("signed_prekey") && !j.at("signed_prekey").is_null()) { + v.signed_prekey = j.at("signed_prekey").get(); + } else { + v.signed_prekey.reset(); + } + if (j.contains("one_time_prekey") && !j.at("one_time_prekey").is_null()) { + v.one_time_prekey = j.at("one_time_prekey").get(); + } else { + v.one_time_prekey.reset(); + } +} + +inline void to_json(nlohmann::json& j, const ResolvePrekeysResponse& v) { + j = nlohmann::json{ + {"username", v.username}, + {"peer_address", v.peer_address}, + {"devices", v.devices}, + }; +} + +inline void from_json(const nlohmann::json& j, ResolvePrekeysResponse& v) { + v.username = j.value("username", ""); + v.peer_address = j.value("peer_address", ""); + if (j.contains("devices")) { + j.at("devices").get_to(v.devices); + } else { + v.devices.clear(); + } +} + +inline void to_json(nlohmann::json& j, const PresenceSetRequest& v) { + j = nlohmann::json{ + {"status", v.status}, + }; +} + +inline void from_json(const nlohmann::json& j, PresenceSetRequest& v) { + v.status = j.value("status", ""); +} + +inline void to_json(nlohmann::json& j, const PresenceSetResponse& v) { + j = nlohmann::json{ + {"status", v.status}, + }; +} + +inline void from_json(const nlohmann::json& j, PresenceSetResponse& v) { + v.status = j.value("status", ""); +} + +inline void to_json(nlohmann::json& j, const PresenceResolveRequest& v) { + j = nlohmann::json{ + {"peer_addresses", v.peer_addresses}, + }; +} + +inline void from_json(const nlohmann::json& j, PresenceResolveRequest& v) { + v.peer_addresses = j.value("peer_addresses", std::vector{}); +} + +inline void to_json(nlohmann::json& j, const PresencePeerOut& v) { + j = nlohmann::json{ + {"peer_address", v.peer_address}, + {"status", v.status}, + }; +} + +inline void from_json(const nlohmann::json& j, PresencePeerOut& v) { + v.peer_address = j.value("peer_address", ""); + v.status = j.value("status", "offline"); +} + +inline void to_json(nlohmann::json& j, const PresenceResolveResponse& v) { + j = nlohmann::json{ + {"self_status", v.self_status}, + {"peers", v.peers}, + }; +} + +inline void from_json(const nlohmann::json& j, PresenceResolveResponse& v) { + v.self_status = j.value("self_status", "offline"); + v.peers = j.value("peers", std::vector{}); +} + inline void from_json(const nlohmann::json& j, WsEventMessageNew& v) { + v.copy_id = j.value("copy_id", ""); j.at("message").get_to(v.message); } @@ -415,6 +1155,82 @@ inline void from_json(const nlohmann::json& j, VoiceAudioChunk& v) { j.at("pcm_b64").get_to(v.pcm_b64); } +inline void to_json(nlohmann::json& j, const VoiceCallWebRtcOffer& v) { + j = nlohmann::json{ + {"call_id", v.call_id}, + {"sdp", v.sdp}, + {"call_schema_version", v.call_schema_version}, + {"call_mode", v.call_mode}, + {"max_participants", v.max_participants}, + {"target_user_address", v.target_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.target_user_address)}, + {"source_user_address", v.source_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.source_user_address)}, + }; +} + +inline void from_json(const nlohmann::json& j, VoiceCallWebRtcOffer& v) { + v.call_id = j.value("call_id", ""); + v.sdp = j.value("sdp", ""); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.target_user_address = j.value("target_user_address", ""); + v.source_user_address = j.value("source_user_address", ""); +} + +inline void to_json(nlohmann::json& j, const VoiceCallWebRtcAnswer& v) { + j = nlohmann::json{ + {"call_id", v.call_id}, + {"sdp", v.sdp}, + {"call_schema_version", v.call_schema_version}, + {"call_mode", v.call_mode}, + {"max_participants", v.max_participants}, + {"target_user_address", v.target_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.target_user_address)}, + {"source_user_address", v.source_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.source_user_address)}, + }; +} + +inline void from_json(const nlohmann::json& j, VoiceCallWebRtcAnswer& v) { + v.call_id = j.value("call_id", ""); + v.sdp = j.value("sdp", ""); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.target_user_address = j.value("target_user_address", ""); + v.source_user_address = j.value("source_user_address", ""); +} + +inline void to_json(nlohmann::json& j, const VoiceCallWebRtcIce& v) { + j = nlohmann::json{ + {"call_id", v.call_id}, + {"candidate", v.candidate}, + {"sdp_mid", v.sdp_mid.empty() ? nlohmann::json(nullptr) : nlohmann::json(v.sdp_mid)}, + {"sdp_mline_index", v.sdp_mline_index < 0 ? nlohmann::json(nullptr) : nlohmann::json(v.sdp_mline_index)}, + {"call_schema_version", v.call_schema_version}, + {"call_mode", v.call_mode}, + {"max_participants", v.max_participants}, + {"target_user_address", v.target_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.target_user_address)}, + {"source_user_address", v.source_user_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.source_user_address)}, + }; +} + +inline void from_json(const nlohmann::json& j, VoiceCallWebRtcIce& v) { + v.call_id = j.value("call_id", ""); + v.candidate = j.value("candidate", ""); + v.sdp_mid = j.value("sdp_mid", ""); + v.sdp_mline_index = j.value("sdp_mline_index", -1); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.target_user_address = j.value("target_user_address", ""); + v.source_user_address = j.value("source_user_address", ""); +} + inline void from_json(const nlohmann::json& j, WsEventCallIncoming& v) { j.at("call_id").get_to(v.call_id); j.at("conversation_id").get_to(v.conversation_id); @@ -434,6 +1250,10 @@ inline void from_json(const nlohmann::json& j, WsEventCallAccepted& v) { j.at("conversation_id").get_to(v.conversation_id); v.peer_user_id = j.value("peer_user_id", ""); v.peer_user_address = j.value("peer_user_address", ""); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", ""); + v.max_participants = j.value("max_participants", 2); + v.ice_servers = j.value("ice_servers", nlohmann::json::array()); } inline void from_json(const nlohmann::json& j, WsEventCallRejected& v) { @@ -453,6 +1273,21 @@ inline void from_json(const nlohmann::json& j, WsEventCallEnded& v) { v.by_user_id = j.value("by_user_id", ""); } +inline void from_json(const nlohmann::json& j, WsEventCallGroupParticipant& v) { + v.member_address = j.value("member_address", ""); + v.state = j.value("state", ""); +} + +inline void from_json(const nlohmann::json& j, WsEventCallGroupState& v) { + v.call_id = j.value("call_id", ""); + v.conversation_id = j.value("conversation_id", ""); + v.group_uid = j.value("group_uid", ""); + v.state = j.value("state", ""); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.participants = j.value("participants", std::vector{}); +} + inline void from_json(const nlohmann::json& j, WsEventCallAudio& v) { j.at("call_id").get_to(v.call_id); v.from_user_id = j.value("from_user_id", ""); @@ -466,4 +1301,47 @@ inline void from_json(const nlohmann::json& j, WsEventCallError& v) { v.detail = j.value("detail", ""); } +inline void from_json(const nlohmann::json& j, WsEventCallWebRtcOffer& v) { + v.call_id = j.value("call_id", ""); + v.sdp = j.value("sdp", ""); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.from_user_address = j.value("from_user_address", j.value("source_user_address", "")); + v.source_user_address = j.value("source_user_address", v.from_user_address); + v.target_user_address = j.value("target_user_address", ""); +} + +inline void from_json(const nlohmann::json& j, WsEventCallWebRtcAnswer& v) { + v.call_id = j.value("call_id", ""); + v.sdp = j.value("sdp", ""); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.from_user_address = j.value("from_user_address", j.value("source_user_address", "")); + v.source_user_address = j.value("source_user_address", v.from_user_address); + v.target_user_address = j.value("target_user_address", ""); +} + +inline void from_json(const nlohmann::json& j, WsEventCallWebRtcIce& v) { + v.call_id = j.value("call_id", ""); + v.candidate = j.value("candidate", ""); + v.sdp_mid = j.value("sdp_mid", ""); + v.sdp_mline_index = j.value("sdp_mline_index", -1); + v.call_schema_version = j.value("call_schema_version", 1); + v.call_mode = j.value("call_mode", "webrtc"); + v.max_participants = j.value("max_participants", 2); + v.from_user_address = j.value("from_user_address", j.value("source_user_address", "")); + v.source_user_address = j.value("source_user_address", v.from_user_address); + v.target_user_address = j.value("target_user_address", ""); +} + +inline void from_json(const nlohmann::json& j, WsEventGroupRenamed& v) { + v.conversation_id = j.value("conversation_id", ""); + v.group_uid = j.value("group_uid", ""); + v.group_name = j.value("group_name", ""); + v.actor_address = j.value("actor_address", ""); + v.event_seq = j.value("event_seq", 0); +} + } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/models/view_models.hpp b/client-cpp-gui/include/blackwire/models/view_models.hpp index 175c8e8..649fd91 100644 --- a/client-cpp-gui/include/blackwire/models/view_models.hpp +++ b/client-cpp-gui/include/blackwire/models/view_models.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include namespace blackwire { @@ -9,6 +11,24 @@ struct ConversationListItemView { QString title; QString subtitle; QString last_activity_at; + QString status = "offline"; + QString conversation_type = "direct"; + bool can_manage_members = false; + QString peer_address; + QString group_name; + int member_count = 0; +}; + +struct GroupInviteCandidateView { + QString peer_address; + QString title; + QString subtitle; + QString status = "offline"; +}; + +struct GroupInvitePickerView { + std::vector candidates; + int remaining_slots = 0; }; struct AudioDeviceOptionView { @@ -16,6 +36,12 @@ struct AudioDeviceOptionView { QString name; }; +struct CallParticipantView { + QString user_address; + QString label; + bool self = false; +}; + struct CallStateView { QString state = "idle"; QString call_id; @@ -23,6 +49,7 @@ struct CallStateView { QString peer_user_id; bool muted = false; QString reason; + std::vector participants; }; struct ThreadMessageView { diff --git a/client-cpp-gui/include/blackwire/storage/client_state.hpp b/client-cpp-gui/include/blackwire/storage/client_state.hpp index 05c269d..4e1dad8 100644 --- a/client-cpp-gui/include/blackwire/storage/client_state.hpp +++ b/client-cpp-gui/include/blackwire/storage/client_state.hpp @@ -15,9 +15,11 @@ struct LocalMessage { std::string id; std::string conversation_id; std::string sender_user_id; + std::string sender_address; std::string created_at; std::string rendered_text; std::string plaintext; + std::string plaintext_cache_b64; }; struct ConversationMeta { @@ -34,6 +36,7 @@ struct AudioPreferences { struct SocialPreferences { bool accept_messages_from_strangers = true; + std::string presence_status = "active"; }; struct ClientState { @@ -46,9 +49,12 @@ struct ClientState { std::map conversation_meta; std::map> local_messages; std::set seen_message_ids; + std::map pinned_sender_sign_keys_by_device_uid; + std::map last_verified_chain_hash_by_conversation_sender; AudioPreferences audio_preferences; SocialPreferences social_preferences; std::set blocked_conversation_ids; + std::set dismissed_conversation_ids; bool MarkMessageSeen(const std::string& message_id) { const auto inserted = seen_message_ids.insert(message_id); @@ -60,8 +66,12 @@ inline void to_json(nlohmann::json& j, const LocalMessage& v) { j = nlohmann::json{{"id", v.id}, {"conversation_id", v.conversation_id}, {"sender_user_id", v.sender_user_id}, + {"sender_address", v.sender_address.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.sender_address)}, {"created_at", v.created_at}, - {"rendered_text", v.rendered_text}}; + {"rendered_text", v.rendered_text}, + {"plaintext_cache_b64", v.plaintext_cache_b64.empty() ? nlohmann::json(nullptr) + : nlohmann::json(v.plaintext_cache_b64)}}; } inline void to_json(nlohmann::json& j, const ConversationMeta& v) { @@ -77,11 +87,17 @@ inline void from_json(const nlohmann::json& j, LocalMessage& v) { j.at("id").get_to(v.id); j.at("conversation_id").get_to(v.conversation_id); j.at("sender_user_id").get_to(v.sender_user_id); + v.sender_address = j.value("sender_address", ""); j.at("created_at").get_to(v.created_at); // Scrub any historical plaintext payload that may have been persisted. v.rendered_text = "[encrypted message]"; // Never deserialize historical plaintext from disk into runtime state. v.plaintext.clear(); + if (j.contains("plaintext_cache_b64") && !j.at("plaintext_cache_b64").is_null()) { + v.plaintext_cache_b64 = j.value("plaintext_cache_b64", ""); + } else { + v.plaintext_cache_b64.clear(); + } } inline void from_json(const nlohmann::json& j, ConversationMeta& v) { @@ -111,11 +127,13 @@ inline void from_json(const nlohmann::json& j, AudioPreferences& v) { inline void to_json(nlohmann::json& j, const SocialPreferences& v) { j = nlohmann::json{ {"accept_messages_from_strangers", v.accept_messages_from_strangers}, + {"presence_status", v.presence_status}, }; } inline void from_json(const nlohmann::json& j, SocialPreferences& v) { v.accept_messages_from_strangers = j.value("accept_messages_from_strangers", true); + v.presence_status = j.value("presence_status", "active"); } inline void to_json(nlohmann::json& j, const ClientState& v) { @@ -126,9 +144,12 @@ inline void to_json(nlohmann::json& j, const ClientState& v) { {"conversation_meta", v.conversation_meta}, {"local_messages", v.local_messages}, {"seen_message_ids", v.seen_message_ids}, + {"pinned_sender_sign_keys_by_device_uid", v.pinned_sender_sign_keys_by_device_uid}, + {"last_verified_chain_hash_by_conversation_sender", v.last_verified_chain_hash_by_conversation_sender}, {"audio_preferences", v.audio_preferences}, {"social_preferences", v.social_preferences}, - {"blocked_conversation_ids", v.blocked_conversation_ids}}; + {"blocked_conversation_ids", v.blocked_conversation_ids}, + {"dismissed_conversation_ids", v.dismissed_conversation_ids}}; if (v.has_user) { j["user"] = v.user; } @@ -161,6 +182,12 @@ inline void from_json(const nlohmann::json& j, ClientState& v) { if (j.contains("seen_message_ids")) { j.at("seen_message_ids").get_to(v.seen_message_ids); } + if (j.contains("pinned_sender_sign_keys_by_device_uid")) { + j.at("pinned_sender_sign_keys_by_device_uid").get_to(v.pinned_sender_sign_keys_by_device_uid); + } + if (j.contains("last_verified_chain_hash_by_conversation_sender")) { + j.at("last_verified_chain_hash_by_conversation_sender").get_to(v.last_verified_chain_hash_by_conversation_sender); + } if (j.contains("audio_preferences")) { j.at("audio_preferences").get_to(v.audio_preferences); } @@ -170,6 +197,9 @@ inline void from_json(const nlohmann::json& j, ClientState& v) { if (j.contains("blocked_conversation_ids")) { j.at("blocked_conversation_ids").get_to(v.blocked_conversation_ids); } + if (j.contains("dismissed_conversation_ids")) { + j.at("dismissed_conversation_ids").get_to(v.dismissed_conversation_ids); + } } } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/ui/chat_widget.hpp b/client-cpp-gui/include/blackwire/ui/chat_widget.hpp index 7b65016..77ee11a 100644 --- a/client-cpp-gui/include/blackwire/ui/chat_widget.hpp +++ b/client-cpp-gui/include/blackwire/ui/chat_widget.hpp @@ -11,8 +11,10 @@ class QLineEdit; class QListWidget; class QPlainTextEdit; class QPushButton; +class QComboBox; class QStackedWidget; class QTimer; +class QHBoxLayout; class QWidget; namespace blackwire { @@ -31,9 +33,11 @@ class ChatWidget final : public QWidget { void SetConversationList(const std::vector& conversations); void SetThreadMessages(const std::vector& messages); void AppendThreadMessage(const ThreadMessageView& message); + void ShowGroupInviteDialog(const GroupInvitePickerView& picker); void SetConnectionStatus(const QString& status); void SetIdentity(const QString& user_address); void SetCallState(const CallStateView& state); + void SetUserStatus(const QString& status); void ShowBanner(const QString& text, const QString& severity); void ClearCompose(); void SetSendEnabled(bool enabled); @@ -42,40 +46,70 @@ class ChatWidget final : public QWidget { signals: void NewConversationRequested(); void ConversationSelected(const QString& conversation_id); + void ConversationRemoveRequested( + const QString& conversation_id, + const QString& conversation_type, + bool can_manage_members, + const QString& title); void SendMessageRequested(); + void SendFileRequested(const QString& file_path); void SettingsRequested(); void StartVoiceCallRequested(); void AcceptVoiceCallRequested(); void RejectVoiceCallRequested(); void EndVoiceCallRequested(); void CallMuteToggled(bool muted); + void UserStatusChanged(const QString& status); + void CreateGroupFromDmRequested(); + void GroupInviteDialogRequested(); + void InviteGroupMembersRequested(const std::vector& addresses); + void GroupRenameRequested(const QString& name); private: bool eventFilter(QObject* watched, QEvent* event) override; - QWidget* CreateConversationItemWidget(const ConversationListItemView& item, bool selected) const; + QWidget* CreateConversationItemWidget(const ConversationListItemView& item, bool selected, bool removable); QWidget* CreateThreadMessageWidget(const ThreadMessageView& message) const; bool IsTimelineNearBottom() const; void ScrollTimelineToBottom(); void SetTimelineHasMessages(bool has_messages); void UpdateThreadHeader(); void RefreshConversationSelectionStyles(); + void RefreshListSelectionStyles(QListWidget* list); + void SyncContactSelection(const QString& conversation_id); + void SetContactsMode(bool enabled); + void UpdateContactsButtonState(); + void UpdatePresenceIndicator(const QString& status); void UpdateCallControls(); + QString SelectedConversationType() const; + bool SelectedConversationCanManageMembers() const; QLineEdit* peer_input_; + QPushButton* contacts_button_; QPushButton* new_chat_button_; QPushButton* settings_button_; QListWidget* conversations_list_; + QListWidget* contacts_panel_list_; QListWidget* messages_list_; + QStackedWidget* content_stack_; + QWidget* chat_panel_; + QWidget* contacts_panel_; QStackedWidget* timeline_stack_; QPlainTextEdit* compose_input_; + QPushButton* attach_button_; QPushButton* send_button_; + QComboBox* presence_combo_; + QLabel* presence_indicator_; QLabel* status_label_; QLabel* call_status_label_; QWidget* call_panel_; QLabel* call_panel_avatar_; QLabel* call_panel_title_; QLabel* call_panel_subtitle_; + QWidget* call_participants_panel_; + QHBoxLayout* call_participants_layout_; QPushButton* call_button_; + QPushButton* group_button_; + QPushButton* invite_button_; QPushButton* accept_call_button_; QPushButton* decline_call_button_; QPushButton* mute_call_button_; @@ -83,11 +117,12 @@ class ChatWidget final : public QWidget { QLabel* identity_label_; QPushButton* copy_identity_button_; QLabel* banner_label_; - QLabel* thread_title_label_; + QLineEdit* thread_title_label_; QLabel* empty_state_label_; QTimer* banner_timer_; QString identity_value_; CallStateView call_state_; + bool contacts_mode_active_ = true; }; } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/ui/main_window.hpp b/client-cpp-gui/include/blackwire/ui/main_window.hpp index da987cb..a6f8d67 100644 --- a/client-cpp-gui/include/blackwire/ui/main_window.hpp +++ b/client-cpp-gui/include/blackwire/ui/main_window.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include namespace blackwire { @@ -24,6 +25,7 @@ class MainWindow final : public QMainWindow { ChatWidget* chat_widget_; SettingsDialog* settings_dialog_; QWidget* stacked_container_; + QString last_integrity_warning_; }; } // namespace blackwire diff --git a/client-cpp-gui/include/blackwire/ui/settings_dialog.hpp b/client-cpp-gui/include/blackwire/ui/settings_dialog.hpp index fc76109..43428d6 100644 --- a/client-cpp-gui/include/blackwire/ui/settings_dialog.hpp +++ b/client-cpp-gui/include/blackwire/ui/settings_dialog.hpp @@ -23,6 +23,7 @@ class QStackedWidget; namespace blackwire { struct AudioDeviceOptionView; +struct DeviceOut; class SettingsDialog final : public QDialog { Q_OBJECT @@ -35,11 +36,13 @@ class SettingsDialog final : public QDialog { void SetDeviceInfo(const QString& label, const QString& device_id); void SetConnectionStatus(const QString& status); void SetDiagnostics(const QString& diagnostics); + void SetAccountDevices(const std::vector& devices, const QString& current_device_uid); void SetAudioDevices( const std::vector& input_devices, const std::vector& output_devices); void SetSelectedAudioDevices(const QString& input_device_id, const QString& output_device_id); void SetAcceptMessagesFromStrangers(bool enabled); + void SetIntegrityWarning(const QString& warning); bool AcceptMessagesFromStrangers() const; protected: @@ -49,6 +52,7 @@ class SettingsDialog final : public QDialog { void LogoutRequested(); void ResetStateRequested(); void CopyDiagnosticsRequested(); + void RevokeDeviceRequested(const QString& device_uid); void ApplyAudioDevicesRequested(const QString& input_device_id, const QString& output_device_id); void AcceptMessagesFromStrangersChanged(bool enabled); @@ -68,6 +72,8 @@ class SettingsDialog final : public QDialog { QString identity_; QString device_id_; + QString current_device_uid_; + QString selected_account_device_uid_; QString diagnostics_; QListWidget* tabs_list_ = nullptr; @@ -78,6 +84,8 @@ class SettingsDialog final : public QDialog { QLabel* device_label_value_ = nullptr; QLabel* device_id_value_ = nullptr; QLabel* connection_status_value_ = nullptr; + QListWidget* account_devices_list_ = nullptr; + QPushButton* revoke_device_button_ = nullptr; QComboBox* input_device_combo_ = nullptr; QComboBox* output_device_combo_ = nullptr; QPushButton* apply_audio_button_ = nullptr; @@ -88,6 +96,7 @@ class SettingsDialog final : public QDialog { QPushButton* copy_id_button_ = nullptr; QPushButton* copy_device_button_ = nullptr; QPushButton* copy_diagnostics_button_ = nullptr; + QLabel* integrity_warning_value_ = nullptr; QPushButton* logout_button_ = nullptr; QPushButton* reset_button_ = nullptr; diff --git a/client-cpp-gui/include/blackwire/ws/qt_ws_client.hpp b/client-cpp-gui/include/blackwire/ws/qt_ws_client.hpp index 4450729..97d8e4d 100644 --- a/client-cpp-gui/include/blackwire/ws/qt_ws_client.hpp +++ b/client-cpp-gui/include/blackwire/ws/qt_ws_client.hpp @@ -17,6 +17,7 @@ class QtWsClient final : public QObject, public IWsClient { void SetHandlers( MessageHandler on_message, CallIncomingHandler on_call_incoming, + CallGroupStateHandler on_call_group_state, CallRingingHandler on_call_ringing, CallAcceptedHandler on_call_accepted, CallRejectedHandler on_call_rejected, @@ -24,6 +25,10 @@ class QtWsClient final : public QObject, public IWsClient { CallEndedHandler on_call_ended, CallAudioHandler on_call_audio, CallErrorHandler on_call_error, + CallWebRtcOfferHandler on_call_webrtc_offer, + CallWebRtcAnswerHandler on_call_webrtc_answer, + CallWebRtcIceHandler on_call_webrtc_ice, + GroupRenamedHandler on_group_renamed, ErrorHandler on_error, StatusHandler on_status) override; void Connect(const std::string& base_url, const std::string& access_token) override; @@ -34,6 +39,9 @@ class QtWsClient final : public QObject, public IWsClient { void SendCallReject(const VoiceCallReject& reject) override; void SendCallEnd(const VoiceCallEnd& end) override; void SendCallAudioChunk(const VoiceAudioChunk& chunk) override; + void SendCallWebRtcOffer(const VoiceCallWebRtcOffer& offer) override; + void SendCallWebRtcAnswer(const VoiceCallWebRtcAnswer& answer) override; + void SendCallWebRtcIce(const VoiceCallWebRtcIce& ice) override; private: void ScheduleReconnect(); @@ -45,6 +53,7 @@ class QtWsClient final : public QObject, public IWsClient { MessageHandler on_message_; CallIncomingHandler on_call_incoming_; + CallGroupStateHandler on_call_group_state_; CallRingingHandler on_call_ringing_; CallAcceptedHandler on_call_accepted_; CallRejectedHandler on_call_rejected_; @@ -52,6 +61,10 @@ class QtWsClient final : public QObject, public IWsClient { CallEndedHandler on_call_ended_; CallAudioHandler on_call_audio_; CallErrorHandler on_call_error_; + CallWebRtcOfferHandler on_call_webrtc_offer_; + CallWebRtcAnswerHandler on_call_webrtc_answer_; + CallWebRtcIceHandler on_call_webrtc_ice_; + GroupRenamedHandler on_group_renamed_; ErrorHandler on_error_; StatusHandler on_status_; diff --git a/client-cpp-gui/src/api/qt_api_client.cpp b/client-cpp-gui/src/api/qt_api_client.cpp index 98b3dae..a65199e 100644 --- a/client-cpp-gui/src/api/qt_api_client.cpp +++ b/client-cpp-gui/src/api/qt_api_client.cpp @@ -26,7 +26,24 @@ QString JoinUrl(const std::string& base_url, const QString& path) { if (base.endsWith('/')) { base.chop(1); } - return base + "/api/v1" + path; + return base + "/api/v2" + path; +} + +int DefaultPortForScheme(const QString& scheme) { + if (scheme == "https" || scheme == "wss") { + return 443; + } + return 80; +} + +QString AuthorityFromBaseUrl(const std::string& base_url) { + const QUrl url(QString::fromStdString(base_url).trimmed()); + if (!url.isValid() || url.host().trimmed().isEmpty()) { + return "local.invalid"; + } + const QString host = url.host().trimmed().toLower(); + const int port = url.port() > 0 ? url.port() : DefaultPortForScheme(url.scheme().trimmed().toLower()); + return QString("%1:%2").arg(host).arg(port); } bool IsOnionHost(const QString& host) { @@ -126,16 +143,60 @@ UserOut QtApiClient::Me(const std::string& base_url, const std::string& access_t return json.get(); } -DeviceOut QtApiClient::RegisterDevice( +AuthResponse QtApiClient::RegisterDevice( const std::string& base_url, - const std::string& access_token, + const std::string& bootstrap_token, const DeviceRegisterRequest& request) { const nlohmann::json body = request; const auto json = RequestJson( "POST", JoinUrl(base_url, "/devices/register"), - QString::fromStdString(access_token), + QString::fromStdString(bootstrap_token), + &body); + return json.get(); +} + +AuthResponse QtApiClient::BindDevice( + const std::string& base_url, + const std::string& bootstrap_token, + const std::string& device_uid, + const std::string& nonce, + long long timestamp_ms, + const std::string& proof_signature_b64) { + nlohmann::json body = { + {"device_uid", device_uid}, + {"nonce", nonce}, + {"timestamp_ms", timestamp_ms}, + {"proof_signature_b64", proof_signature_b64}, + }; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, "/auth/bind-device"), + QString::fromStdString(bootstrap_token), &body); + return json.get(); +} + +std::vector QtApiClient::ListDevices( + const std::string& base_url, + const std::string& access_token) { + const auto json = RequestJson( + "GET", + JoinUrl(base_url, "/devices"), + QString::fromStdString(access_token), + nullptr); + return json.get>(); +} + +DeviceOut QtApiClient::RevokeDevice( + const std::string& base_url, + const std::string& access_token, + const std::string& device_uid) { + const auto json = RequestJson( + "POST", + JoinUrl(base_url, QString("/devices/%1/revoke").arg(QString::fromStdString(device_uid))), + QString::fromStdString(access_token), + nullptr); return json.get(); } @@ -146,13 +207,17 @@ UserDeviceLookup QtApiClient::GetUserDevice( const QString peer = QString::fromStdString(peer_address).trimmed(); QString endpoint; if (peer.contains('@')) { - QUrl url(JoinUrl(base_url, "/users/resolve-device")); + QUrl url(JoinUrl(base_url, "/users/resolve-devices")); QUrlQuery query; query.addQueryItem("peer_address", peer); url.setQuery(query); endpoint = url.toString(); } else { - endpoint = JoinUrl(base_url, QString("/users/%1/device").arg(peer)); + QUrl url(JoinUrl(base_url, "/users/resolve-devices")); + QUrlQuery query; + query.addQueryItem("peer_address", QString("%1@%2").arg(peer, AuthorityFromBaseUrl(base_url))); + url.setQuery(query); + endpoint = url.toString(); } const auto json = RequestJson( @@ -163,6 +228,63 @@ UserDeviceLookup QtApiClient::GetUserDevice( return json.get(); } +ResolvePrekeysResponse QtApiClient::ResolvePrekeys( + const std::string& base_url, + const std::string& access_token, + const std::string& peer_address) { + const QString peer = QString::fromStdString(peer_address).trimmed(); + QUrl url(JoinUrl(base_url, "/users/resolve-prekeys")); + QUrlQuery query; + query.addQueryItem("peer_address", peer); + url.setQuery(query); + + const auto json = RequestJson( + "GET", + url.toString(), + QString::fromStdString(access_token), + nullptr); + return json.get(); +} + +PrekeyUploadResponse QtApiClient::UploadPrekeys( + const std::string& base_url, + const std::string& access_token, + const PrekeyUploadRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, "/keys/prekeys/upload"), + QString::fromStdString(access_token), + &body); + return json.get(); +} + +PresenceSetResponse QtApiClient::SetPresenceStatus( + const std::string& base_url, + const std::string& access_token, + const PresenceSetRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, "/presence/set"), + QString::fromStdString(access_token), + &body); + return json.get(); +} + +PresenceResolveResponse QtApiClient::ResolvePresence( + const std::string& base_url, + const std::string& access_token, + const PresenceResolveRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, "/presence/resolve"), + QString::fromStdString(access_token), + &body); + return json.get(); +} + ConversationOut QtApiClient::CreateDm( const std::string& base_url, const std::string& access_token, @@ -183,6 +305,19 @@ ConversationOut QtApiClient::CreateDm( return json.get(); } +ConversationOut QtApiClient::CreateGroup( + const std::string& base_url, + const std::string& access_token, + const CreateGroupConversationRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, "/conversations/group"), + QString::fromStdString(access_token), + &body); + return json.get(); +} + std::vector QtApiClient::ListConversations( const std::string& base_url, const std::string& access_token) { @@ -194,6 +329,82 @@ std::vector QtApiClient::ListConversations( return json.get>(); } +std::vector QtApiClient::ListConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) { + const auto json = RequestJson( + "GET", + JoinUrl(base_url, QString("/conversations/%1/members").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + nullptr); + return json.get>(); +} + +std::vector QtApiClient::InviteConversationMembers( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupInviteRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, QString("/conversations/%1/members/invite").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + &body); + return json.get>(); +} + +ConversationOut QtApiClient::RenameConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id, + const GroupRenameRequest& request) { + const nlohmann::json body = request; + const auto json = RequestJson( + "POST", + JoinUrl(base_url, QString("/conversations/%1/rename").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + &body); + return json.get(); +} + +ConversationMemberOut QtApiClient::AcceptConversationInvite( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) { + const auto json = RequestJson( + "POST", + JoinUrl(base_url, QString("/conversations/%1/invites/accept").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + nullptr); + return json.get(); +} + +ConversationMemberOut QtApiClient::LeaveConversationGroup( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) { + const auto json = RequestJson( + "POST", + JoinUrl(base_url, QString("/conversations/%1/leave").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + nullptr); + return json.get(); +} + +ConversationRecipientsOut QtApiClient::GetConversationRecipients( + const std::string& base_url, + const std::string& access_token, + const std::string& conversation_id) { + const auto json = RequestJson( + "GET", + JoinUrl(base_url, QString("/conversations/%1/recipients").arg(QString::fromStdString(conversation_id))), + QString::fromStdString(access_token), + nullptr); + return json.get(); +} + std::vector QtApiClient::ListMessages( const std::string& base_url, const std::string& access_token, @@ -207,7 +418,22 @@ std::vector QtApiClient::ListMessages( url.setQuery(query); const auto json = RequestJson("GET", url.toString(), QString::fromStdString(access_token), nullptr); - return json.get>(); + std::vector out; + if (!json.is_array()) { + return out; + } + for (const auto& item : json) { + if (item.contains("message")) { + auto message = item.at("message").get(); + if (item.contains("envelope_json")) { + message.envelope = item.at("envelope_json").get(); + } + out.push_back(message); + continue; + } + out.push_back(item.get()); + } + return out; } MessageSendResponse QtApiClient::SendMessage( diff --git a/client-cpp-gui/src/audio/qt_audio_call_engine.cpp b/client-cpp-gui/src/audio/qt_audio_call_engine.cpp index 44c59df..6ee9348 100644 --- a/client-cpp-gui/src/audio/qt_audio_call_engine.cpp +++ b/client-cpp-gui/src/audio/qt_audio_call_engine.cpp @@ -1,6 +1,8 @@ #include "blackwire/audio/qt_audio_call_engine.hpp" #include +#include +#include #include #include @@ -254,12 +256,43 @@ void QtAudioCallEngine::FlushPlayback() { return; } - QByteArray frame; + QByteArray frame(kFrameBytes, '\0'); if (playback_queue_.empty()) { - frame = QByteArray(kFrameBytes, '\0'); + // Keep writing silence when no remote audio is available. } else { - frame = playback_queue_.front(); - playback_queue_.pop_front(); + // Mix multiple queued frames into a single 20ms buffer so group calls + // do not build unbounded playback delay as participant count grows. + const int frames_to_mix = std::min(kMaxFramesMixedPerTick, static_cast(playback_queue_.size())); + std::array mixed{}; + + for (int i = 0; i < frames_to_mix; ++i) { + QByteArray chunk = playback_queue_.front(); + playback_queue_.pop_front(); + if (chunk.size() < kFrameBytes) { + chunk.append(QByteArray(kFrameBytes - chunk.size(), '\0')); + } else if (chunk.size() > kFrameBytes) { + chunk.truncate(kFrameBytes); + } + + const auto* input = reinterpret_cast(chunk.constData()); + for (int sample = 0; sample < kSamplesPerFrame; ++sample) { + mixed[sample] += static_cast(input[sample]); + } + } + + auto* output = reinterpret_cast(frame.data()); + const int divisor = std::max(1, frames_to_mix); + for (int sample = 0; sample < kSamplesPerFrame; ++sample) { + const int averaged = mixed[sample] / divisor; + output[sample] = static_cast(std::clamp( + averaged, + static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max()))); + } + + while (static_cast(playback_queue_.size()) > kPlaybackLatencyCapFrames) { + playback_queue_.pop_front(); + } } const qint64 written = sink_device_->write(frame); diff --git a/client-cpp-gui/src/controller/application_controller.cpp b/client-cpp-gui/src/controller/application_controller.cpp index f6facac..96ce9e5 100644 --- a/client-cpp-gui/src/controller/application_controller.cpp +++ b/client-cpp-gui/src/controller/application_controller.cpp @@ -2,13 +2,19 @@ #include #include +#include #include #include +#include +#include #include #include #include #include +#include +#include +#include #include #include #include @@ -24,11 +30,98 @@ namespace blackwire { namespace { constexpr int kMaxDiagnostics = 200; +constexpr qint64 kConservativeInlineAttachmentFallbackBytes = 1 * 1024 * 1024; +constexpr qint64 kSealedBoxOverheadBytes = 64; +const char* kFileMessagePrefix = "bwfile://v1:"; QString SanitizeDiagnosticText(QString text) { return SanitizeDiagnosticsText(text); } +bool EnvFlagEnabled(const char* name, bool default_value) { + const QString raw = qEnvironmentVariable(name).trimmed().toLower(); + if (raw.isEmpty()) { + return default_value; + } + return raw == "1" || raw == "true" || raw == "yes" || raw == "on"; +} + +QString NormalizePresenceValue(const QString& status) { + const QString normalized = status.trimmed().toLower(); + if (normalized == "active" || normalized == "inactive" || normalized == "offline" || normalized == "dnd") { + return normalized; + } + return "active"; +} + +QString FormatBytesHuman(qint64 bytes) { + if (bytes >= 1024LL * 1024LL) { + const double mb = static_cast(bytes) / (1024.0 * 1024.0); + return QString("%1 MB").arg(mb, 0, 'f', mb >= 100 ? 0 : 1); + } + if (bytes >= 1024LL) { + const double kb = static_cast(bytes) / 1024.0; + return QString("%1 KB").arg(kb, 0, 'f', kb >= 100 ? 0 : 1); + } + return QString("%1 B").arg(bytes); +} + +QString UsernameFromAddress(const QString& actor_address) { + const QString normalized = actor_address.trimmed().toLower(); + if (normalized.isEmpty()) { + return "A member"; + } + const int at = normalized.indexOf('@'); + if (at > 0) { + return normalized.left(at).trimmed(); + } + return normalized; +} + +std::string CanonicalMessageSignature( + const std::string& sender_address, + const std::string& sender_device_uid, + const std::string& recipient_address, + const std::string& recipient_device_uid, + const std::string& client_message_id, + long long sent_at_ms, + const std::string& sender_prev_hash, + const std::string& sender_chain_hash, + const std::string& ciphertext_hash, + const std::string& aad_hash) { + return sender_address + "\n" + sender_device_uid + "\n" + recipient_address + "\n" + recipient_device_uid + "\n" + + client_message_id + "\n" + std::to_string(sent_at_ms) + "\n" + sender_prev_hash + "\n" + sender_chain_hash + + "\n" + ciphertext_hash + "\n" + aad_hash; +} + +std::string AggregateChainHash( + ICryptoService& crypto, + const std::string& sender_prev_hash, + const std::string& client_message_id, + long long sent_at_ms, + const std::vector& hash_material) { + std::vector ordered = hash_material; + std::sort(ordered.begin(), ordered.end()); + std::string aggregate; + for (std::size_t i = 0; i < ordered.size(); ++i) { + if (i > 0) { + aggregate.append("|"); + } + aggregate.append(ordered[i]); + } + const std::string aggregate_hash = crypto.Sha256(aggregate); + return crypto.Sha256( + sender_prev_hash + "\n" + client_message_id + "\n" + std::to_string(sent_at_ms) + "\n" + aggregate_hash); +} + +std::string CanonicalSignedPrekeyString( + const std::string& device_uid, + int key_id, + const std::string& pub_x25519_b64, + const std::string& expires_at) { + return "SIGNED_PREKEY\n" + device_uid + "\n" + std::to_string(key_id) + "\n" + pub_x25519_b64 + "\n" + expires_at; +} + } // namespace ApplicationController::ApplicationController( @@ -50,21 +143,100 @@ ApplicationController::ApplicationController( profile_name_(profile_name.trimmed().isEmpty() ? "default" : profile_name.trimmed().toLower().toStdString()) { qRegisterMetaType("ConversationListItemView"); qRegisterMetaType>("std::vector"); + qRegisterMetaType("blackwire::DeviceOut"); + qRegisterMetaType>("std::vector"); qRegisterMetaType("blackwire::AudioDeviceOptionView"); qRegisterMetaType>("std::vector"); qRegisterMetaType("blackwire::CallStateView"); qRegisterMetaType("blackwire::ThreadMessageView"); qRegisterMetaType>("std::vector"); + presence_poll_timer_ = new QTimer(this); + presence_poll_timer_->setInterval(5000); + connect(presence_poll_timer_, &QTimer::timeout, this, [this]() { + if (!state_.has_user || !state_.has_device) { + return; + } + LoadConversations(); + }); + presence_poll_timer_->start(); + ws_client_.SetHandlers( [this](const WsEventMessageNew& event) { try { const auto& msg = event.message; - if (!state_.MarkMessageSeen(msg.id)) { - ws_client_.SendAck(msg.id); + const std::string ack_id = event.copy_id.empty() ? msg.id : event.copy_id; + const std::string dedupe_id = ack_id.empty() ? msg.id : ack_id; + if (!state_.MarkMessageSeen(dedupe_id)) { + ws_client_.SendAck(ack_id.empty() ? msg.id : ack_id); return; } + if (!msg.sender_device_uid.empty() && !msg.sender_device_pubkey.empty() && + !msg.envelope.signature_b64.empty()) { + const auto pin_it = state_.pinned_sender_sign_keys_by_device_uid.find(msg.sender_device_uid); + if (pin_it != state_.pinned_sender_sign_keys_by_device_uid.end() && + pin_it->second != msg.sender_device_pubkey) { + const QString line = QString("Integrity warning: sender key changed for device %1") + .arg(QString::fromStdString(msg.sender_device_uid)); + RecordDiagnostic(line); + emit IntegrityWarningOccurred(line); + return; + } + state_.pinned_sender_sign_keys_by_device_uid[msg.sender_device_uid] = msg.sender_device_pubkey; + + const QByteArray ciphertext_bytes = QByteArray::fromBase64( + QByteArray::fromStdString(msg.envelope.ciphertext_b64)); + const QByteArray aad_bytes = QByteArray::fromBase64( + QByteArray::fromStdString(msg.envelope.aad_b64)); + const std::string ciphertext_hash = crypto_.Sha256(ciphertext_bytes.toStdString()); + const std::string aad_hash = crypto_.Sha256(aad_bytes.toStdString()); + const std::string recipient_address = msg.envelope.recipient_user_address.empty() + ? QString("%1@%2") + .arg( + QString::fromStdString(state_.user.username).trimmed().toLower(), + QString::fromStdString( + state_.user.home_server_onion.empty() + ? ServerAuthority().toStdString() + : state_.user.home_server_onion)) + .toStdString() + : msg.envelope.recipient_user_address; + const std::string canonical = CanonicalMessageSignature( + msg.sender_address, + msg.sender_device_uid, + recipient_address, + msg.envelope.recipient_device_uid.empty() ? msg.envelope.recipient_device_id + : msg.envelope.recipient_device_uid, + msg.client_message_id, + msg.sent_at_ms, + msg.sender_prev_hash, + msg.sender_chain_hash, + ciphertext_hash, + aad_hash); + if (!crypto_.VerifyDetached( + msg.sender_device_pubkey, + canonical, + msg.envelope.signature_b64)) { + const QString line = QString("Integrity warning: invalid sender signature from %1") + .arg(QString::fromStdString(msg.sender_address)); + RecordDiagnostic(line); + emit IntegrityWarningOccurred(line); + return; + } + + const std::string chain_key = msg.conversation_id + "|" + msg.sender_device_uid; + const auto chain_it = state_.last_verified_chain_hash_by_conversation_sender.find(chain_key); + const std::string expected_prev = chain_it == state_.last_verified_chain_hash_by_conversation_sender.end() + ? std::string() + : chain_it->second; + if (msg.sender_prev_hash != expected_prev) { + const QString line = QString("Integrity warning: missing/reordered message detected"); + RecordDiagnostic(line); + emit IntegrityWarningOccurred(line); + } + state_.last_verified_chain_hash_by_conversation_sender[chain_key] = msg.sender_chain_hash; + } + std::string plaintext = "[unable to decrypt]"; try { std::string error; @@ -76,15 +248,19 @@ ApplicationController::ApplicationController( plaintext = "[unable to decrypt]"; } - ws_client_.SendAck(msg.id); + ws_client_.SendAck(ack_id.empty() ? msg.id : ack_id); LocalMessage local; local.id = msg.id; local.conversation_id = msg.conversation_id; local.sender_user_id = msg.sender_user_id; + local.sender_address = msg.sender_address; local.created_at = msg.created_at; local.rendered_text = RenderMessage(msg, plaintext).toStdString(); local.plaintext = plaintext; + if (state_.dismissed_conversation_ids.contains(msg.conversation_id)) { + state_.dismissed_conversation_ids.erase(msg.conversation_id); + } if (state_.blocked_conversation_ids.contains(msg.conversation_id)) { PersistState(); return; @@ -171,6 +347,10 @@ ApplicationController::ApplicationController( }, [this](const WsEventCallIncoming& event) { if (!IsCallState("idle")) { + if (!call_state_.call_id.trimmed().isEmpty() && + call_state_.call_id == QString::fromStdString(event.call_id)) { + return; + } VoiceCallReject reject; reject.call_id = event.call_id; reject.reason = "busy"; @@ -178,6 +358,8 @@ ApplicationController::ApplicationController( return; } + pending_outgoing_end_request_ = false; + call_initiated_locally_ = false; TransitionCallState( "incoming_ringing", QString::fromStdString(event.call_id), @@ -187,8 +369,224 @@ ApplicationController::ApplicationController( QString()); emit IncomingCallReceived(call_state_); }, + [this](const WsEventCallGroupState& event) { + const QString state = QString::fromStdString(event.state).trimmed().toLower(); + const QString call_id = QString::fromStdString(event.call_id); + const QString conversation_id = QString::fromStdString(event.conversation_id); + QString title = FriendlyConversationTitle(event.conversation_id); + if (title.trimmed().isEmpty()) { + title = "Group"; + } + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() + ? ServerAuthority().toStdString() + : state_.user.home_server_onion) + .trimmed() + .toLower(); + const QString self_address = QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), local_server) + .trimmed() + .toLower(); + QString self_participant_state; + std::vector joined_participants; + joined_participants.reserve(event.participants.size()); + for (const auto& participant : event.participants) { + const QString participant_address = QString::fromStdString(participant.member_address).trimmed().toLower(); + if (participant_address == self_address) { + self_participant_state = QString::fromStdString(participant.state).trimmed().toLower(); + } + const QString participant_state = QString::fromStdString(participant.state).trimmed().toLower(); + if (participant_state != "joined") { + continue; + } + const bool participant_is_self = participant_address == self_address; + joined_participants.push_back(CallParticipantView{ + participant_address, + participant_is_self ? "You" : UsernameFromAddress(participant_address), + participant_is_self, + }); + } + std::sort( + joined_participants.begin(), + joined_participants.end(), + [](const CallParticipantView& lhs, const CallParticipantView& rhs) { + if (lhs.self != rhs.self) { + return lhs.self > rhs.self; + } + return lhs.label.trimmed().toLower() < rhs.label.trimmed().toLower(); + } + ); + + if (state == "ringing") { + if (self_participant_state == "left" || + self_participant_state == "declined" || + self_participant_state == "missed" || + self_participant_state == "removed") { + pending_outgoing_end_request_ = false; + StopAudioEngine(); + TransitionCallState( + "idle", + QString(), + conversation_id, + QString(), + self_participant_state.isEmpty() ? "left" : self_participant_state); + return; + } + if (IsCallState("idle")) { + pending_outgoing_end_request_ = false; + call_initiated_locally_ = false; + TransitionCallState( + "incoming_ringing", + call_id, + conversation_id, + title, + QString()); + emit IncomingCallReceived(call_state_); + return; + } + if (!IsCallState("outgoing_ringing") && + !IsCallState("incoming_ringing") && + !(pending_outgoing_end_request_ && IsCallState("ending"))) { + return; + } + if (!call_state_.call_id.isEmpty() && call_state_.call_id != call_id) { + return; + } + const QString next_state = IsCallState("incoming_ringing") ? "incoming_ringing" : "outgoing_ringing"; + TransitionCallState( + next_state, + call_id, + conversation_id, + title, + QString()); + if (pending_outgoing_end_request_ && !call_state_.call_id.trimmed().isEmpty()) { + VoiceCallEnd end_request; + end_request.call_id = call_state_.call_id.toStdString(); + end_request.reason = "ended"; + ws_client_.SendCallEnd(end_request); + pending_outgoing_end_request_ = false; + TransitionCallState( + "ending", + call_state_.call_id, + call_state_.conversation_id, + call_state_.peer_user_id, + "ending"); + } + return; + } + + if (state == "active") { + if (self_participant_state == "left" || + self_participant_state == "declined" || + self_participant_state == "missed" || + self_participant_state == "removed") { + pending_outgoing_end_request_ = false; + StopAudioEngine(); + TransitionCallState( + "idle", + QString(), + conversation_id, + QString(), + self_participant_state.isEmpty() ? "left" : self_participant_state); + return; + } + + // Another participant joined, but this local user has not joined yet. + // Keep local UI in ringing state until this user explicitly accepts. + if (self_participant_state != "joined") { + if (!call_state_.call_id.isEmpty() && call_state_.call_id != call_id) { + return; + } + if (IsCallState("idle")) { + pending_outgoing_end_request_ = false; + call_initiated_locally_ = false; + TransitionCallState( + "incoming_ringing", + call_id, + conversation_id, + title, + QString()); + emit IncomingCallReceived(call_state_); + return; + } + if (IsCallState("active")) { + StopAudioEngine(); + TransitionCallState( + "incoming_ringing", + call_id, + conversation_id, + title, + QString()); + emit IncomingCallReceived(call_state_); + return; + } + if (IsCallState("incoming_ringing") || IsCallState("outgoing_ringing") || IsCallState("ending")) { + const QString keep_state = IsCallState("outgoing_ringing") ? "outgoing_ringing" : "incoming_ringing"; + TransitionCallState( + keep_state, + call_id, + conversation_id, + title, + QString()); + } + return; + } + + const bool was_active_same_call = IsCallState("active") && call_state_.call_id == call_id; + if (!call_state_.call_id.isEmpty() && call_state_.call_id != call_id) { + return; + } + TransitionCallState( + "active", + call_id, + conversation_id, + title, + QString()); + call_state_.participants = joined_participants; + call_state_.reason = "Group call active"; + EmitCallState(); + + if (!was_active_same_call) { + const bool legacy_audio_enabled = EnvFlagEnabled("BLACKWIRE_ENABLE_LEGACY_CALL_AUDIO_WS", true); + if (legacy_audio_enabled) { + QString warning; + QString error; + if (!StartAudioEngineForActiveCall(&warning, &error)) { + ReportCallError(QString("Voice call audio start failed: %1").arg(error)); + if (!call_state_.call_id.trimmed().isEmpty()) { + VoiceCallEnd end_request; + end_request.call_id = call_state_.call_id.toStdString(); + end_request.reason = "audio_error"; + ws_client_.SendCallEnd(end_request); + } + StopAudioEngine(); + TransitionCallState("idle", QString(), QString(), QString(), "audio_error"); + return; + } + if (!warning.trimmed().isEmpty()) { + ReportCallError(warning); + } + } + } + return; + } + + if (state == "ended") { + if (!call_state_.call_id.isEmpty() && call_state_.call_id != call_id) { + return; + } + pending_outgoing_end_request_ = false; + StopAudioEngine(); + TransitionCallState( + "idle", + QString(), + conversation_id, + QString(), + "ended"); + } + }, [this](const WsEventCallRinging& event) { - if (!IsCallState("outgoing_ringing")) { + if (!IsCallState("outgoing_ringing") && !(pending_outgoing_end_request_ && IsCallState("ending"))) { return; } if (!call_state_.call_id.isEmpty() && call_state_.call_id != QString::fromStdString(event.call_id)) { @@ -202,20 +600,75 @@ ApplicationController::ApplicationController( QString::fromStdString( event.peer_user_address.empty() ? event.peer_user_id : event.peer_user_address), QString()); + if (pending_outgoing_end_request_ && !call_state_.call_id.trimmed().isEmpty()) { + VoiceCallEnd end_request; + end_request.call_id = call_state_.call_id.toStdString(); + end_request.reason = "ended"; + ws_client_.SendCallEnd(end_request); + pending_outgoing_end_request_ = false; + TransitionCallState( + "ending", + call_state_.call_id, + call_state_.conversation_id, + call_state_.peer_user_id, + "ending"); + } }, [this](const WsEventCallAccepted& event) { const QString call_id = QString::fromStdString(event.call_id); + const QString conversation_id = QString::fromStdString(event.conversation_id); + const auto* conversation = FindConversation(conversation_id.toStdString()); + if (conversation != nullptr && conversation->conversation_type == "group") { + // Group-call state is driven by call.group.state; ignore generic accepted events. + return; + } if (!call_state_.call_id.isEmpty() && call_state_.call_id != call_id) { return; } + if (IsCallState("incoming_ringing")) { + call_initiated_locally_ = false; + } else if (IsCallState("outgoing_ringing")) { + call_initiated_locally_ = true; + } TransitionCallState( "active", call_id, - QString::fromStdString(event.conversation_id), + conversation_id, QString::fromStdString( event.peer_user_address.empty() ? event.peer_user_id : event.peer_user_address), QString()); + call_state_.participants = BuildDirectCallParticipants(call_state_.peer_user_id); + EmitCallState(); + + const bool webrtc_enabled = EnvFlagEnabled("BLACKWIRE_ENABLE_WEBRTC_V2B2", false) || + event.call_mode == "webrtc"; + const bool legacy_audio_enabled = EnvFlagEnabled("BLACKWIRE_ENABLE_LEGACY_CALL_AUDIO_WS", true); + if (webrtc_enabled) { + VoiceCallWebRtcOffer offer; + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() ? ServerAuthority().toStdString() : state_.user.home_server_onion); + const QString self_address = QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), local_server) + .trimmed() + .toLower(); + offer.call_id = event.call_id; + offer.sdp = "v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\ns=Blackwire\r\nt=0 0\r\nm=audio 9 RTP/AVP 0\r\n"; + offer.source_user_address = self_address.toStdString(); + if (call_state_.peer_user_id.contains('@')) { + offer.target_user_address = call_state_.peer_user_id.trimmed().toLower().toStdString(); + } + ws_client_.SendCallWebRtcOffer(offer); + call_state_.reason = "WebRTC negotiating"; + if (event.ice_servers.is_array()) { + RecordDiagnostic( + QString("webrtc ice_servers=%1").arg(static_cast(event.ice_servers.size()))); + } + EmitCallState(); + if (!legacy_audio_enabled) { + return; + } + } QString warning; QString error; @@ -238,6 +691,7 @@ ApplicationController::ApplicationController( if (!call_state_.call_id.isEmpty() && call_state_.call_id != QString::fromStdString(event.call_id)) { return; } + pending_outgoing_end_request_ = false; StopAudioEngine(); TransitionCallState( "idle", @@ -250,6 +704,7 @@ ApplicationController::ApplicationController( if (!IsCallState("outgoing_ringing")) { return; } + pending_outgoing_end_request_ = false; StopAudioEngine(); TransitionCallState( "idle", @@ -264,6 +719,7 @@ ApplicationController::ApplicationController( return; } + pending_outgoing_end_request_ = false; StopAudioEngine(); TransitionCallState( "idle", @@ -293,6 +749,85 @@ ApplicationController::ApplicationController( } ReportCallError(message); }, + [this](const WsEventCallWebRtcOffer& event) { + RecordDiagnostic( + QString("received webrtc offer call_id=%1").arg(QString::fromStdString(event.call_id))); + if (IsCallState("active") && call_state_.call_id == QString::fromStdString(event.call_id)) { + const QString source_address = QString::fromStdString( + event.source_user_address.empty() ? event.from_user_address : event.source_user_address) + .trimmed() + .toLower(); + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() ? ServerAuthority().toStdString() : state_.user.home_server_onion); + const QString self_address = QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), local_server) + .trimmed() + .toLower(); + VoiceCallWebRtcAnswer answer; + answer.call_id = event.call_id; + answer.sdp = "v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\ns=Blackwire\r\nt=0 0\r\nm=audio 9 RTP/AVP 0\r\n"; + answer.source_user_address = self_address.toStdString(); + answer.target_user_address = source_address.toStdString(); + ws_client_.SendCallWebRtcAnswer(answer); + call_state_.reason = "WebRTC negotiating"; + EmitCallState(); + } + }, + [this](const WsEventCallWebRtcAnswer& event) { + RecordDiagnostic( + QString("received webrtc answer call_id=%1").arg(QString::fromStdString(event.call_id))); + if (IsCallState("active") && call_state_.call_id == QString::fromStdString(event.call_id)) { + call_state_.reason = "WebRTC connected"; + EmitCallState(); + } + }, + [this](const WsEventCallWebRtcIce& event) { + RecordDiagnostic( + QString("received webrtc ice call_id=%1").arg(QString::fromStdString(event.call_id))); + if (IsCallState("active") && call_state_.call_id == QString::fromStdString(event.call_id)) { + call_state_.reason = "WebRTC negotiating"; + EmitCallState(); + } + }, + [this](const WsEventGroupRenamed& event) { + const QString conversation_id = QString::fromStdString(event.conversation_id).trimmed(); + const QString next_group_name = QString::fromStdString(event.group_name).trimmed(); + if (conversation_id.isEmpty() || next_group_name.isEmpty()) { + return; + } + + QString previous_group_name; + for (auto& conversation : state_.conversations) { + if (conversation.id != conversation_id.toStdString()) { + continue; + } + previous_group_name = QString::fromStdString(conversation.group_name).trimmed(); + conversation.group_name = next_group_name.toStdString(); + break; + } + + QString dedupe_suffix; + if (!previous_group_name.isEmpty()) { + dedupe_suffix = QString("sync-%1->%2") + .arg(previous_group_name.toLower(), next_group_name.toLower()); + } else if (event.event_seq > 0) { + dedupe_suffix = QString("event-%1").arg(event.event_seq); + } else { + dedupe_suffix = QString("event-%1-%2") + .arg(next_group_name.toLower(), QString::fromStdString(event.actor_address).toLower()); + } + AppendGroupRenameHistoryEntry( + conversation_id, + QString::fromStdString(event.actor_address), + next_group_name, + dedupe_suffix); + + PersistState(); + RefreshConversationList(); + if (selected_conversation_id_ == conversation_id.toStdString()) { + emit ConversationSelected(conversation_id, RenderThread(selected_conversation_id_)); + } + }, [this](const std::string& error) { if (IsWebSocketAuthError(error)) { ReauthenticateWebSocket(); @@ -306,8 +841,16 @@ ApplicationController::ApplicationController( connection_status_ = connected ? "Connected" : "Disconnected"; RecordDiagnostic(QString("connection_status=%1").arg(connection_status_)); emit ConnectionStatusChanged(connection_status_); + if (!connected) { + emit UserPresenceChanged("offline"); + } else { + emit UserPresenceChanged(user_presence_status_); + RefreshPresenceCache(); + RefreshConversationList(); + } if (!connected && !IsCallState("idle")) { + pending_outgoing_end_request_ = false; StopAudioEngine(); TransitionCallState("idle", QString(), QString(), QString(), "connection_lost"); } @@ -320,6 +863,10 @@ void ApplicationController::Initialize() { if (state_.base_url.empty()) { state_.base_url = "http://localhost:8000"; } + user_presence_status_ = NormalizePresenceValue(QString::fromStdString(state_.social_preferences.presence_status)); + state_.social_preferences.presence_status = user_presence_status_.toStdString(); + const QString preferred_presence_status = user_presence_status_; + DecryptPlaintextCacheInState(); TransitionCallState("idle", QString(), QString(), QString(), QString()); LoadAudioDevices(); @@ -329,10 +876,12 @@ void ApplicationController::Initialize() { state_.has_user, state_.has_user ? QString::fromStdString(state_.user.username) : QString()); emit DeviceStateChanged(state_.has_device); + emit UserPresenceChanged(state_.has_user && state_.has_device ? user_presence_status_ : "offline"); if (state_.has_user && state_.has_device) { LoadConversations(); StartRealtime(); + SetPresenceStatus(preferred_presence_status); } } catch (const std::exception& ex) { const QString line = QString("Initialization failed: %1").arg(ex.what()); @@ -365,10 +914,13 @@ void ApplicationController::Register(const QString& username, const QString& pas state_.local_messages.clear(); state_.seen_message_ids.clear(); peer_device_cache_.clear(); + peer_presence_status_by_address_.clear(); pending_request_messages_.clear(); pending_request_senders_.clear(); + state_.social_preferences.presence_status = "active"; + user_presence_status_ = "active"; - SaveTokenPair(response.tokens); + SaveBootstrapToken(response.tokens); PersistState(); RecordDiagnostic(QString("register success user=%1").arg(QString::fromStdString(state_.user.username))); @@ -388,6 +940,10 @@ void ApplicationController::Login(const QString& username, const QString& passwo state_.user = response.user; state_.has_user = true; peer_device_cache_.clear(); + peer_presence_status_by_address_.clear(); + user_presence_status_ = NormalizePresenceValue(QString::fromStdString(state_.social_preferences.presence_status)); + state_.social_preferences.presence_status = user_presence_status_.toStdString(); + const QString preferred_presence_status = user_presence_status_; if (state_.has_device && state_.device.user_id != state_.user.id) { state_.has_device = false; @@ -396,11 +952,12 @@ void ApplicationController::Login(const QString& username, const QString& passwo state_.local_messages.clear(); state_.seen_message_ids.clear(); peer_device_cache_.clear(); + peer_presence_status_by_address_.clear(); pending_request_messages_.clear(); pending_request_senders_.clear(); } - SaveTokenPair(response.tokens); + SaveBootstrapToken(response.tokens); PersistState(); RecordDiagnostic(QString("login success user=%1").arg(QString::fromStdString(state_.user.username))); @@ -409,8 +966,40 @@ void ApplicationController::Login(const QString& username, const QString& passwo LoadAudioDevices(); if (state_.has_device) { - LoadConversations(); - StartRealtime(); + try { + std::string error; + const auto ik_private = secret_store_.GetSecret(SecretKey("ik_private"), &error); + if (!ik_private.has_value()) { + throw std::runtime_error("Device signing key missing"); + } + const std::string nonce = QUuid::createUuid().toString(QUuid::WithoutBraces).toStdString(); + const long long timestamp_ms = QDateTime::currentMSecsSinceEpoch(); + const std::string canonical = + "BIND_DEVICE\n" + state_.user.id + "\n" + state_.device.id + "\n" + nonce + "\n" + + std::to_string(timestamp_ms); + const std::string signature = crypto_.SignDetached(ik_private.value(), canonical); + const auto bound = api_client_.BindDevice( + state_.base_url, + RequireBootstrapToken(), + state_.device.id, + nonce, + timestamp_ms, + signature); + SaveTokenPair(bound.tokens); + DecryptPlaintextCacheInState(); + UploadCurrentDevicePrekeys(); + PersistState(); + LoadConversations(); + StartRealtime(); + SetPresenceStatus(preferred_presence_status); + } catch (const std::exception& ex) { + const QString line = QString("Device re-bind failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + state_.has_device = false; + PersistState(); + emit DeviceStateChanged(false); + } } } catch (const std::exception& ex) { const QString line = QString("Login failed: %1").arg(ex.what()); @@ -429,6 +1018,7 @@ void ApplicationController::Logout() { std::string error; secret_store_.DeleteSecret(SecretKey("access_token"), &error); secret_store_.DeleteSecret(SecretKey("refresh_token"), &error); + secret_store_.DeleteSecret(SecretKey("bootstrap_token"), &error); secret_store_.DeleteSecret(SecretKey("ik_private"), &error); secret_store_.DeleteSecret(SecretKey("enc_private"), &error); @@ -442,8 +1032,10 @@ void ApplicationController::Logout() { RecordDiagnostic("logout completed"); emit AuthStateChanged(false, QString()); emit DeviceStateChanged(false); + emit AccountDevicesChanged(std::vector{}); emit ConversationListChanged(std::vector{}); emit ConversationSelected(QString(), {}); + emit UserPresenceChanged("offline"); LoadAudioDevices(); } @@ -454,15 +1046,27 @@ void ApplicationController::SetupDevice(const QString& label) { request.label = label.toStdString(); request.ik_ed25519_pub = keys.ik_ed25519_public_b64; request.enc_x25519_pub = keys.enc_x25519_public_b64; + request.pub_sign_key = keys.ik_ed25519_public_b64; + request.pub_dh_key = keys.enc_x25519_public_b64; const auto operation = [this, &request]() { - return api_client_.RegisterDevice(state_.base_url, RequireAccessToken(), request); + return api_client_.RegisterDevice(state_.base_url, RequireBootstrapToken(), request); }; - const auto device = CallWithAuthRetryOnce(operation, [this]() { RefreshAccessToken(); }); + const auto response = operation(); + SaveTokenPair(response.tokens); - state_.device = device; + state_.device.id = response.tokens.device_uid; + state_.device.device_uid = response.tokens.device_uid; + state_.device.user_id = response.user.id; + state_.device.label = label.toStdString(); + state_.device.ik_ed25519_pub = keys.ik_ed25519_public_b64; + state_.device.enc_x25519_pub = keys.enc_x25519_public_b64; + state_.device.status = "active"; state_.has_device = true; + user_presence_status_ = NormalizePresenceValue(QString::fromStdString(state_.social_preferences.presence_status)); + state_.social_preferences.presence_status = user_presence_status_.toStdString(); + const QString preferred_presence_status = user_presence_status_; std::string error; if (!secret_store_.SetSecret(SecretKey("ik_private"), keys.ik_ed25519_private_b64, &error)) { @@ -472,6 +1076,7 @@ void ApplicationController::SetupDevice(const QString& label) { throw std::runtime_error(error); } + UploadCurrentDevicePrekeys(); PersistState(); RecordDiagnostic(QString("device setup complete label=%1").arg(label)); emit DeviceStateChanged(true); @@ -479,6 +1084,7 @@ void ApplicationController::SetupDevice(const QString& label) { LoadConversations(); StartRealtime(); + SetPresenceStatus(preferred_presence_status); } catch (const std::exception& ex) { const QString line = QString("Device setup failed: %1").arg(ex.what()); RecordDiagnostic(line); @@ -486,8 +1092,80 @@ void ApplicationController::SetupDevice(const QString& label) { } } +void ApplicationController::LoadAccountDevices() { + try { + if (!state_.has_user || !state_.has_device) { + emit AccountDevicesChanged(std::vector{}); + return; + } + + const auto operation = [this]() { + return api_client_.ListDevices(state_.base_url, RequireAccessToken()); + }; + + auto devices = CallWithAuthRetryOnce>(operation, [this]() { RefreshAccessToken(); }); + std::sort(devices.begin(), devices.end(), [](const DeviceOut& lhs, const DeviceOut& rhs) { + const std::string lhs_uid = lhs.device_uid.empty() ? lhs.id : lhs.device_uid; + const std::string rhs_uid = rhs.device_uid.empty() ? rhs.id : rhs.device_uid; + if (lhs.status != rhs.status) { + return lhs.status < rhs.status; + } + if (lhs.created_at != rhs.created_at) { + return lhs.created_at > rhs.created_at; + } + return lhs_uid < rhs_uid; + }); + emit AccountDevicesChanged(devices); + } catch (const std::exception& ex) { + const QString line = QString("Load account devices failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + } +} + +void ApplicationController::RevokeDevice(const QString& device_uid) { + try { + if (!state_.has_user || !state_.has_device) { + return; + } + const std::string target_uid = device_uid.trimmed().toStdString(); + if (target_uid.empty()) { + return; + } + + const auto operation = [this, &target_uid]() { + return api_client_.RevokeDevice(state_.base_url, RequireAccessToken(), target_uid); + }; + const auto revoked = CallWithAuthRetryOnce(operation, [this]() { RefreshAccessToken(); }); + const std::string revoked_uid = revoked.device_uid.empty() ? revoked.id : revoked.device_uid; + RecordDiagnostic(QString("revoked device uid=%1").arg(QString::fromStdString(revoked_uid))); + + peer_device_cache_.clear(); + if (state_.has_device && revoked_uid == state_.device.id) { + emit IntegrityWarningOccurred("This device has been revoked. Sign in again."); + Logout(); + return; + } + + LoadAccountDevices(); + } catch (const std::exception& ex) { + const QString line = QString("Revoke device failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + } +} + void ApplicationController::LoadConversations() { try { + std::unordered_map> previous_group_state; + previous_group_state.reserve(state_.conversations.size()); + for (const auto& conversation : state_.conversations) { + if (conversation.conversation_type != "group") { + continue; + } + previous_group_state[conversation.id] = std::make_pair(conversation.group_name, conversation.owner_address); + } + const auto operation = [this]() { return api_client_.ListConversations(state_.base_url, RequireAccessToken()); }; @@ -506,8 +1184,31 @@ void ApplicationController::LoadConversations() { QString(), QString::fromStdString(conv.created_at)); } + + if (conv.conversation_type != "group") { + continue; + } + const auto previous = previous_group_state.find(conv.id); + if (previous == previous_group_state.end()) { + continue; + } + const QString previous_name = QString::fromStdString(previous->second.first).trimmed(); + const QString next_name = QString::fromStdString(conv.group_name).trimmed(); + if (previous_name.isEmpty() || next_name.isEmpty() || previous_name == next_name) { + continue; + } + QString actor_address = QString::fromStdString(conv.owner_address).trimmed().toLower(); + if (actor_address.isEmpty()) { + actor_address = QString::fromStdString(previous->second.second).trimmed().toLower(); + } + AppendGroupRenameHistoryEntry( + QString::fromStdString(conv.id), + actor_address, + next_name, + QString("sync-%1->%2").arg(previous_name.toLower(), next_name.toLower())); } + RefreshPresenceCache(); PersistState(); RefreshConversationList(); RecordDiagnostic(QString("loaded conversations=%1").arg(static_cast(state_.conversations.size()))); @@ -544,6 +1245,7 @@ void ApplicationController::OpenConversationByPeer(const QString& username) { } UpsertConversationMeta(conversation_id, peer_username, peer_address, QString(), created_at); PersistState(); + RefreshPresenceCache(); RefreshConversationList(); SelectConversation(QString::fromStdString(conversation_id)); } catch (const std::exception& ex) { @@ -557,6 +1259,27 @@ void ApplicationController::SelectConversation(const QString& conversation_id) { try { selected_conversation_id_ = conversation_id.toStdString(); + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected != nullptr && selected->conversation_type == "group") { + const QString membership_state = QString::fromStdString(selected->membership_state).trimmed().toLower(); + if (membership_state == "invited") { + const auto accept_op = [this]() { + return api_client_.AcceptConversationInvite( + state_.base_url, + RequireAccessToken(), + selected_conversation_id_); + }; + (void)CallWithAuthRetryOnce(accept_op, [this]() { RefreshAccessToken(); }); + LoadConversations(); + RecordDiagnostic( + QString("accepted group invite conversation_id=%1") + .arg(QString::fromStdString(selected_conversation_id_))); + } else if (membership_state != "active" && membership_state != "none") { + emit ErrorOccurred("You are not an active member of this group."); + return; + } + } + const auto operation = [this]() { return api_client_.ListMessages( state_.base_url, @@ -580,9 +1303,37 @@ void ApplicationController::SelectConversation(const QString& conversation_id) { std::vector rebuilt; rebuilt.reserve(remote_messages.size()); + std::unordered_set rebuilt_ids; + const QString self_address = QString("%1@%2") + .arg( + QString::fromStdString(state_.user.username).trimmed().toLower(), + QString::fromStdString( + state_.user.home_server_onion.empty() + ? ServerAuthority().toStdString() + : state_.user.home_server_onion)) + .trimmed() + .toLower(); std::string error; const auto private_key = secret_store_.GetSecret(SecretKey("enc_private"), &error); + auto parse_time = [](const std::string& value) { + const QString iso = QString::fromStdString(value); + QDateTime parsed = QDateTime::fromString(iso, Qt::ISODateWithMs); + if (!parsed.isValid()) { + parsed = QDateTime::fromString(iso, Qt::ISODate); + } + return parsed; + }; + QDateTime oldest_remote_time; + for (const auto& message : remote_messages) { + const QDateTime created = parse_time(message.created_at); + if (!created.isValid()) { + continue; + } + if (!oldest_remote_time.isValid() || created < oldest_remote_time) { + oldest_remote_time = created; + } + } for (const auto& message : remote_messages) { auto existing = existing_by_id.find(message.id); @@ -590,7 +1341,15 @@ void ApplicationController::SelectConversation(const QString& conversation_id) { if (existing->second.plaintext.empty()) { existing->second.plaintext = ExtractLegacyPlaintext(QString::fromStdString(existing->second.rendered_text)).toStdString(); } + if (existing->second.sender_address.empty()) { + if (!message.sender_address.empty()) { + existing->second.sender_address = message.sender_address; + } else if (existing->second.sender_user_id == state_.user.id) { + existing->second.sender_address = self_address.toStdString(); + } + } rebuilt.push_back(existing->second); + rebuilt_ids.insert(existing->second.id); state_.MarkMessageSeen(message.id); continue; } @@ -612,14 +1371,53 @@ void ApplicationController::SelectConversation(const QString& conversation_id) { local.id = message.id; local.conversation_id = message.conversation_id; local.sender_user_id = message.sender_user_id; + local.sender_address = message.sender_address.empty() && message.sender_user_id == state_.user.id + ? self_address.toStdString() + : message.sender_address; local.created_at = message.created_at; local.rendered_text = RenderMessage(message, plaintext).toStdString(); local.plaintext = plaintext; rebuilt.push_back(local); + rebuilt_ids.insert(local.id); state_.MarkMessageSeen(message.id); } + if (existing_it != state_.local_messages.end()) { + for (const auto& existing : existing_it->second) { + if (rebuilt_ids.contains(existing.id)) { + continue; + } + // Server listing is device-copy scoped and paginated, so preserve + // self/system/older-cached entries that are not in this window. + const bool self_sent = existing.sender_user_id == state_.user.id; + const bool local_system_entry = existing.sender_user_id.empty(); + bool older_than_fetch_window = false; + if (oldest_remote_time.isValid()) { + const QDateTime existing_time = parse_time(existing.created_at); + older_than_fetch_window = existing_time.isValid() && existing_time < oldest_remote_time; + } + if (self_sent || local_system_entry || older_than_fetch_window) { + LocalMessage carry = existing; + if (carry.sender_address.empty()) { + carry.sender_address = self_address.toStdString(); + } + rebuilt.push_back(std::move(carry)); + } + } + } + std::stable_sort(rebuilt.begin(), rebuilt.end(), [&parse_time](const LocalMessage& lhs, const LocalMessage& rhs) { + const QDateTime lhs_time = parse_time(lhs.created_at); + const QDateTime rhs_time = parse_time(rhs.created_at); + if (lhs_time.isValid() && rhs_time.isValid() && lhs_time != rhs_time) { + return lhs_time < rhs_time; + } + if (lhs.created_at != rhs.created_at) { + return lhs.created_at < rhs.created_at; + } + return lhs.id < rhs.id; + }); + state_.local_messages[selected_conversation_id_] = rebuilt; if (!rebuilt.empty()) { const auto& last = rebuilt.back(); @@ -644,13 +1442,433 @@ void ApplicationController::SelectConversation(const QString& conversation_id) { } } +bool ApplicationController::DismissDirectConversation(const QString& conversation_id) { + try { + const std::string conversation_key = conversation_id.trimmed().toStdString(); + if (conversation_key.empty()) { + return false; + } + const ConversationOut* selected = FindConversation(conversation_key); + if (selected == nullptr || selected->conversation_type != "direct") { + emit ErrorOccurred("Only direct messages can be dismissed."); + return false; + } + + state_.dismissed_conversation_ids.insert(conversation_key); + state_.local_messages.erase(conversation_key); + state_.conversation_meta.erase(conversation_key); + pending_request_messages_.erase(conversation_key); + pending_request_senders_.erase(conversation_key); + + if (selected_conversation_id_ == conversation_key) { + selected_conversation_id_.clear(); + emit ConversationSelected(QString(), {}); + } + + PersistState(); + RefreshConversationList(); + RecordDiagnostic(QString("dismissed direct conversation id=%1").arg(conversation_id)); + return true; + } catch (const std::exception& ex) { + const QString line = QString("Dismiss direct conversation failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + return false; + } +} + +bool ApplicationController::LeaveGroupConversation(const QString& conversation_id) { + try { + const std::string conversation_key = conversation_id.trimmed().toStdString(); + if (conversation_key.empty()) { + return false; + } + const ConversationOut* selected = FindConversation(conversation_key); + if (selected == nullptr || selected->conversation_type != "group") { + emit ErrorOccurred("Only group conversations can be left."); + return false; + } + + const auto operation = [this, &conversation_key]() { + return api_client_.LeaveConversationGroup( + state_.base_url, + RequireAccessToken(), + conversation_key); + }; + (void)CallWithAuthRetryOnce(operation, [this]() { RefreshAccessToken(); }); + + state_.local_messages.erase(conversation_key); + state_.conversation_meta.erase(conversation_key); + state_.blocked_conversation_ids.erase(conversation_key); + state_.dismissed_conversation_ids.erase(conversation_key); + pending_request_messages_.erase(conversation_key); + pending_request_senders_.erase(conversation_key); + + if (selected_conversation_id_ == conversation_key) { + selected_conversation_id_.clear(); + emit ConversationSelected(QString(), {}); + } + + LoadConversations(); + RecordDiagnostic(QString("left group conversation id=%1").arg(conversation_id)); + return true; + } catch (const ApiException& ex) { + QString message = QString::fromStdString(ex.what()); + if (ex.status_code() == 404 && message.contains("disabled", Qt::CaseInsensitive)) { + message = "Group DMs are disabled on this federation."; + } + const QString line = QString("Leave group conversation failed: %1").arg(message); + RecordDiagnostic(line); + emit ErrorOccurred(message); + return false; + } catch (const std::exception& ex) { + const QString line = QString("Leave group conversation failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + return false; + } +} + +bool ApplicationController::CreateGroupFromCurrentDm() { + try { + if (!state_.has_user || !state_.has_device) { + emit ErrorOccurred("Sign in before creating a group."); + return false; + } + if (selected_conversation_id_.empty()) { + emit ErrorOccurred("Select a direct message first."); + return false; + } + + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected == nullptr || selected->conversation_type != "direct") { + emit ErrorOccurred("Group creation is only available from a direct message."); + return false; + } + + const QString peer_address = ResolvePeerAddressForConversation(selected->id).trimmed().toLower(); + if (peer_address.isEmpty()) { + emit ErrorOccurred("Cannot determine the selected peer address."); + return false; + } + + const auto parsed_peer = ParsePeerAddress(peer_address); + if (!parsed_peer.has_value() || parsed_peer->username.trimmed().isEmpty()) { + emit ErrorOccurred("Selected DM peer is invalid for group creation."); + return false; + } + + const QString self_username = QString::fromStdString(state_.user.username).trimmed().toLower(); + const QString peer_username = parsed_peer->username.trimmed().toLower(); + QStringList name_parts; + if (!self_username.isEmpty()) { + name_parts.push_back(self_username); + } + if (!peer_username.isEmpty()) { + name_parts.push_back(peer_username); + } + name_parts.removeDuplicates(); + std::sort(name_parts.begin(), name_parts.end(), [](const QString& lhs, const QString& rhs) { + return lhs.toLower() < rhs.toLower(); + }); + const QString group_name = name_parts.isEmpty() ? "group" : name_parts.join(" + "); + + CreateGroupConversationRequest request; + request.name = group_name.toStdString(); + request.member_addresses = {peer_address.toStdString()}; + + const auto operation = [this, &request]() { + return api_client_.CreateGroup(state_.base_url, RequireAccessToken(), request); + }; + const auto created = CallWithAuthRetryOnce(operation, [this]() { RefreshAccessToken(); }); + + LoadConversations(); + RevealConversation(created.id, false); + SelectConversation(QString::fromStdString(created.id)); + RecordDiagnostic( + QString("created group conversation id=%1 with peer=%2") + .arg(QString::fromStdString(created.id), peer_address)); + return true; + } catch (const ApiException& ex) { + QString message = QString::fromStdString(ex.what()); + if (ex.status_code() == 404 && message.contains("disabled", Qt::CaseInsensitive)) { + message = "Group DMs are disabled on this federation."; + } + const QString line = QString("Create group failed: %1").arg(message); + RecordDiagnostic(line); + emit ErrorOccurred(message); + return false; + } catch (const std::exception& ex) { + const QString line = QString("Create group failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + return false; + } +} + +GroupInvitePickerView ApplicationController::LoadInvitableContactsForCurrentGroup(const QString& query) { + GroupInvitePickerView picker; + try { + bool max_members_ok = false; + int configured_group_max_members = qEnvironmentVariableIntValue("BLACKWIRE_GROUP_MAX_MEMBERS", &max_members_ok); + if (!max_members_ok || configured_group_max_members <= 0) { + configured_group_max_members = 32; + } + if (!state_.has_user || !state_.has_device || selected_conversation_id_.empty()) { + return picker; + } + + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected == nullptr || selected->conversation_type != "group" || !selected->can_manage_members) { + return picker; + } + + const auto members_op = [this]() { + return api_client_.ListConversationMembers( + state_.base_url, + RequireAccessToken(), + selected_conversation_id_); + }; + const auto members = CallWithAuthRetryOnce>( + members_op, + [this]() { RefreshAccessToken(); }); + + std::unordered_set excluded_addresses; + std::unordered_set active_or_invited_identities; + for (const auto& member : members) { + const std::string normalized = QString::fromStdString(member.member_address).trimmed().toLower().toStdString(); + const QString status = QString::fromStdString(member.status).trimmed().toLower(); + if (status == "active" || status == "invited") { + if (!normalized.empty()) { + excluded_addresses.insert(normalized); + } + const QString identity = !member.member_user_id.empty() + ? QString("user:%1") + .arg(QString::fromStdString(member.member_user_id).trimmed().toLower()) + : QString("addr:%1").arg(QString::fromStdString(normalized)); + active_or_invited_identities.insert(identity.toStdString()); + } + } + picker.remaining_slots = std::max(0, configured_group_max_members - static_cast(active_or_invited_identities.size())); + + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() + ? ServerAuthority().toStdString() + : state_.user.home_server_onion) + .trimmed() + .toLower(); + const QString self_address = QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), local_server); + excluded_addresses.insert(self_address.toStdString()); + + const QString normalized_query = query.trimmed().toLower(); + std::unordered_set seen_candidates; + + for (const auto& conv : state_.conversations) { + if (conv.conversation_type != "direct") { + continue; + } + if (state_.blocked_conversation_ids.contains(conv.id)) { + continue; + } + + const QString peer_address = ResolvePeerAddressForConversation(conv.id).trimmed().toLower(); + const std::string key = peer_address.toStdString(); + if (key.empty() || excluded_addresses.contains(key) || seen_candidates.contains(key)) { + continue; + } + + QString title = FriendlyConversationTitle(conv.id).trimmed(); + if (title.isEmpty()) { + title = peer_address.section('@', 0, 0); + } + const QString subtitle = peer_address; + const QString haystack = QString("%1 %2").arg(title, subtitle).toLower(); + if (!normalized_query.isEmpty() && !haystack.contains(normalized_query)) { + continue; + } + + GroupInviteCandidateView candidate; + candidate.peer_address = peer_address; + candidate.title = title; + candidate.subtitle = subtitle; + candidate.status = PresenceForPeerAddress(peer_address); + picker.candidates.push_back(candidate); + seen_candidates.insert(key); + } + + std::sort( + picker.candidates.begin(), + picker.candidates.end(), + [](const GroupInviteCandidateView& lhs, const GroupInviteCandidateView& rhs) { + return lhs.title.trimmed().toLower() < rhs.title.trimmed().toLower(); + }); + } catch (const std::exception& ex) { + const QString line = QString("Load invitable contacts failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + } + return picker; +} + +bool ApplicationController::InviteContactsToCurrentGroup(const std::vector& peer_addresses) { + try { + if (!IsSelectedConversationOwnerManagedGroup()) { + emit ErrorOccurred("Only the group owner can invite members."); + return false; + } + if (selected_conversation_id_.empty()) { + emit ErrorOccurred("Select a group conversation first."); + return false; + } + + std::vector normalized_addresses; + normalized_addresses.reserve(peer_addresses.size()); + std::unordered_set seen; + for (const auto& raw : peer_addresses) { + const QString normalized = raw.trimmed().toLower(); + if (normalized.isEmpty()) { + continue; + } + const auto parsed = ParsePeerAddress(normalized); + if (!parsed.has_value()) { + continue; + } + const std::string canonical = parsed->has_server + ? QString("%1@%2") + .arg(parsed->username.trimmed().toLower(), parsed->server_authority.trimmed().toLower()) + .toStdString() + : parsed->username.trimmed().toLower().toStdString(); + if (canonical.empty() || seen.contains(canonical)) { + continue; + } + seen.insert(canonical); + normalized_addresses.push_back(canonical); + } + + if (normalized_addresses.empty()) { + emit ErrorOccurred("Select at least one contact to invite."); + return false; + } + + GroupInviteRequest request; + request.member_addresses = normalized_addresses; + const auto operation = [this, &request]() { + return api_client_.InviteConversationMembers( + state_.base_url, + RequireAccessToken(), + selected_conversation_id_, + request); + }; + (void)CallWithAuthRetryOnce>(operation, [this]() { RefreshAccessToken(); }); + + const QString keep_selected = QString::fromStdString(selected_conversation_id_); + LoadConversations(); + SelectConversation(keep_selected); + RecordDiagnostic( + QString("invited %1 member(s) to group id=%2") + .arg(static_cast(normalized_addresses.size())) + .arg(keep_selected)); + return true; + } catch (const ApiException& ex) { + QString message = QString::fromStdString(ex.what()); + if (ex.status_code() == 404 && message.contains("disabled", Qt::CaseInsensitive)) { + message = "Group DMs are disabled on this federation."; + } + const QString line = QString("Group invite failed: %1").arg(message); + RecordDiagnostic(line); + emit ErrorOccurred(message); + return false; + } catch (const std::exception& ex) { + const QString line = QString("Group invite failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + return false; + } +} + +bool ApplicationController::RenameSelectedGroup(const QString& new_name) { + try { + if (!IsSelectedConversationOwnerManagedGroup()) { + emit ErrorOccurred("Only the group owner can rename this group."); + return false; + } + if (selected_conversation_id_.empty()) { + emit ErrorOccurred("Select a group conversation first."); + return false; + } + + const QString normalized_name = new_name.trimmed(); + if (normalized_name.isEmpty()) { + emit ErrorOccurred("Group name cannot be empty."); + return false; + } + + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected == nullptr || selected->conversation_type != "group") { + emit ErrorOccurred("Selected conversation is not a group."); + return false; + } + const QString current_name = QString::fromStdString(selected->group_name).trimmed(); + if (!current_name.isEmpty() && current_name == normalized_name) { + return true; + } + + GroupRenameRequest request; + request.name = normalized_name.toStdString(); + const auto operation = [this, &request]() { + return api_client_.RenameConversationGroup( + state_.base_url, + RequireAccessToken(), + selected_conversation_id_, + request); + }; + (void)CallWithAuthRetryOnce(operation, [this]() { RefreshAccessToken(); }); + + const QString keep_selected = QString::fromStdString(selected_conversation_id_); + LoadConversations(); + SelectConversation(keep_selected); + RecordDiagnostic( + QString("renamed group id=%1 to '%2'") + .arg(keep_selected, normalized_name)); + return true; + } catch (const std::exception& ex) { + const QString line = QString("Group rename failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + return false; + } +} + +bool ApplicationController::IsSelectedConversationOwnerManagedGroup() const { + if (selected_conversation_id_.empty()) { + return false; + } + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected == nullptr) { + return false; + } + return selected->conversation_type == "group" && selected->can_manage_members; +} + void ApplicationController::SendMessageToPeer(const QString& peer_username, const QString& message_text) { try { if (message_text.trimmed().isEmpty()) { return; } + const QString home_server = QString::fromStdString( + state_.user.home_server_onion.empty() ? ServerAuthority().toStdString() : state_.user.home_server_onion); + const std::string self_address = + QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), home_server) + .toStdString(); + + bool group_send = false; QString normalized_peer; + std::string conversation_id; + std::vector group_recipients; + if (!peer_username.trimmed().isEmpty()) { QString error; const auto normalized = NormalizePeerUsername(peer_username, &error); @@ -659,7 +1877,389 @@ void ApplicationController::SendMessageToPeer(const QString& peer_username, cons return; } normalized_peer = normalized->trimmed().toLower(); + conversation_id = ResolveConversationIdForPeer(normalized_peer); + RevealConversation(conversation_id, false); } else if (!selected_conversation_id_.empty()) { + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected != nullptr && selected->conversation_type == "group") { + group_send = true; + conversation_id = selected->id; + const auto recipients_op = [this, &conversation_id]() { + return api_client_.GetConversationRecipients( + state_.base_url, + RequireAccessToken(), + conversation_id); + }; + const auto recipients = CallWithAuthRetryOnce( + recipients_op, + [this]() { RefreshAccessToken(); }); + group_recipients = recipients.recipients; + RevealConversation(conversation_id, false); + } else { + const auto meta_it = state_.conversation_meta.find(selected_conversation_id_); + if (meta_it == state_.conversation_meta.end() || + (meta_it->second.peer_address.empty() && meta_it->second.peer_username.empty())) { + emit ErrorOccurred("Missing peer identity for selected conversation. Re-open the DM from username."); + return; + } + if (!meta_it->second.peer_address.empty()) { + normalized_peer = QString::fromStdString(meta_it->second.peer_address).trimmed().toLower(); + } else { + normalized_peer = QString::fromStdString(meta_it->second.peer_username).trimmed().toLower(); + } + conversation_id = ResolveConversationIdForPeer(normalized_peer); + RevealConversation(conversation_id, false); + } + } else { + emit ErrorOccurred("Please enter a valid username or username@onion"); + return; + } + + struct SendTarget { + std::string recipient_address; + DeviceOut device; + bool self_mirror = false; + std::optional prekey; + }; + + std::vector targets; + std::set seen_targets; + std::string peer_address; + + if (group_send) { + for (const auto& recipient : group_recipients) { + DeviceOut device = recipient.device; + if (device.status == "revoked") { + continue; + } + const std::string target_uid = device.device_uid.empty() ? device.id : device.device_uid; + if (target_uid.empty() || seen_targets.contains(target_uid) || target_uid == state_.device.id) { + continue; + } + seen_targets.insert(target_uid); + if (device.device_uid.empty()) { + device.device_uid = target_uid; + } + if (device.id.empty()) { + device.id = target_uid; + } + targets.push_back(SendTarget{ + recipient.member_address, + device, + recipient.member_address == self_address, + recipient.prekey, + }); + } + } else { + const std::vector recipient_devices = ResolveRecipientDevices(normalized_peer); + const std::vector own_devices = ResolveOwnActiveDevices(); + peer_address = normalized_peer.contains('@') + ? normalized_peer.toStdString() + : QString("%1@%2").arg(normalized_peer, home_server).toStdString(); + + for (const auto& device : recipient_devices) { + const std::string target_uid = device.device_uid.empty() ? device.id : device.device_uid; + if (target_uid.empty() || seen_targets.contains(target_uid)) { + continue; + } + seen_targets.insert(target_uid); + DeviceOut normalized_device = device; + if (normalized_device.device_uid.empty()) { + normalized_device.device_uid = target_uid; + } + if (normalized_device.id.empty()) { + normalized_device.id = target_uid; + } + targets.push_back(SendTarget{peer_address, normalized_device, false, std::nullopt}); + } + + for (const auto& device : own_devices) { + const std::string target_uid = device.device_uid.empty() ? device.id : device.device_uid; + if (target_uid.empty() || target_uid == state_.device.id || seen_targets.contains(target_uid)) { + continue; + } + seen_targets.insert(target_uid); + DeviceOut normalized_device = device; + if (normalized_device.device_uid.empty()) { + normalized_device.device_uid = target_uid; + } + if (normalized_device.id.empty()) { + normalized_device.id = target_uid; + } + targets.push_back(SendTarget{self_address, normalized_device, true, std::nullopt}); + } + } + + bool use_ratchet_mode = PreferRatchetV2b1() && !targets.empty(); + if (use_ratchet_mode) { + for (const auto& target : targets) { + if (!DeviceSupportsMessageMode(target.device, "ratchet_v0_2b1")) { + use_ratchet_mode = false; + break; + } + } + } + std::unordered_map prekeys_by_device; + for (const auto& target : targets) { + const std::string target_uid = + target.device.device_uid.empty() ? target.device.id : target.device.device_uid; + if (target_uid.empty() || !target.prekey.has_value()) { + continue; + } + prekeys_by_device[target_uid] = target.prekey.value(); + } + if (use_ratchet_mode) { + if (!group_send) { + try { + const auto prekey_op = [this, &peer_address]() { + return api_client_.ResolvePrekeys(state_.base_url, RequireAccessToken(), peer_address); + }; + const auto prekeys = + CallWithAuthRetryOnce(prekey_op, [this]() { RefreshAccessToken(); }); + for (const auto& item : prekeys.devices) { + prekeys_by_device[item.device_uid] = item; + } + } catch (const std::exception& ex) { + use_ratchet_mode = false; + RecordDiagnostic(QString("ratchet fallback: prekey resolve failed (%1)").arg(ex.what())); + } + } + if (use_ratchet_mode) { + for (const auto& target : targets) { + if (target.self_mirror) { + continue; + } + const std::string target_uid = + target.device.device_uid.empty() ? target.device.id : target.device.device_uid; + const auto it = prekeys_by_device.find(target_uid); + if (it == prekeys_by_device.end() || !it->second.signed_prekey.has_value()) { + use_ratchet_mode = false; + break; + } + } + if (!use_ratchet_mode) { + RecordDiagnostic("ratchet fallback: missing signed prekeys for one or more recipient devices"); + } + } + } + + MessageSendRequest request; + request.conversation_id = conversation_id; + request.encryption_mode = use_ratchet_mode ? "ratchet_v0_2b1" : "sealedbox_v0_2a"; + if (use_ratchet_mode) { + RecordDiagnostic("message send mode=ratchet_v0_2b1"); + } else if (PreferRatchetV2b1()) { + RecordDiagnostic("message send mode=sealedbox_v0_2a (fallback)"); + } + request.client_message_id = QUuid::createUuid().toString(QUuid::WithoutBraces).toStdString(); + request.sent_at_ms = QDateTime::currentMSecsSinceEpoch(); + const std::string sender_chain_key = conversation_id + "|" + state_.device.id; + const auto chain_it = state_.last_verified_chain_hash_by_conversation_sender.find(sender_chain_key); + request.sender_prev_hash = + chain_it == state_.last_verified_chain_hash_by_conversation_sender.end() ? "" : chain_it->second; + + struct PendingEnvelope { + CipherEnvelope envelope; + std::string ciphertext_hash; + std::string aad_hash; + }; + std::vector pending; + + for (const auto& target : targets) { + const auto& device = target.device; + const std::string target_uid = device.device_uid.empty() ? device.id : device.device_uid; + if (target_uid.empty()) { + continue; + } + PendingEnvelope pending_env; + pending_env.envelope.recipient_user_address = target.recipient_address; + pending_env.envelope.recipient_device_uid = target_uid; + pending_env.envelope.recipient_device_id = target_uid; + pending_env.envelope.ciphertext_b64 = + crypto_.EncryptForRecipient(device.enc_x25519_pub, message_text.toStdString()); + pending_env.envelope.aad_b64.clear(); + pending_env.envelope.sender_device_pubkey = state_.device.ik_ed25519_pub; + if (use_ratchet_mode) { + pending_env.envelope.ratchet_header = nlohmann::json{ + {"v", "dr_v1"}, + {"dh_pub", state_.device.enc_x25519_pub}, + {"n", 0}, + {"pn", 0}, + }; + if (target.self_mirror) { + pending_env.envelope.ratchet_header["self_mirror"] = true; + } + const auto prekey_it = prekeys_by_device.find(target_uid); + if (prekey_it != prekeys_by_device.end()) { + nlohmann::json init = { + {"scheme", "x3dh_v1"}, + {"sender_ephemeral_pub", state_.device.enc_x25519_pub}, + {"opk_missing", prekey_it->second.opk_missing}, + }; + if (prekey_it->second.signed_prekey.has_value()) { + init["signed_prekey_id"] = prekey_it->second.signed_prekey->key_id; + } + if (prekey_it->second.one_time_prekey.has_value()) { + init["one_time_prekey_id"] = prekey_it->second.one_time_prekey->key_id; + } else { + init["one_time_prekey_id"] = nullptr; + } + pending_env.envelope.ratchet_init = init; + } + } + const QByteArray ciphertext_bytes = QByteArray::fromBase64( + QByteArray::fromStdString(pending_env.envelope.ciphertext_b64)); + pending_env.ciphertext_hash = crypto_.Sha256(ciphertext_bytes.toStdString()); + pending_env.aad_hash = crypto_.Sha256(std::string()); + pending.push_back(pending_env); + } + + if (pending.empty()) { + use_ratchet_mode = false; + request.encryption_mode = "sealedbox_v0_2a"; + DeviceOut self_device; + self_device.id = state_.device.id; + self_device.device_uid = state_.device.id; + self_device.user_id = state_.user.id; + self_device.label = state_.device.label; + self_device.ik_ed25519_pub = state_.device.ik_ed25519_pub; + self_device.enc_x25519_pub = state_.device.enc_x25519_pub; + self_device.status = "active"; + + PendingEnvelope fallback_env; + fallback_env.envelope.recipient_user_address = self_address; + fallback_env.envelope.recipient_device_uid = state_.device.id; + fallback_env.envelope.recipient_device_id = state_.device.id; + fallback_env.envelope.ciphertext_b64 = + crypto_.EncryptForRecipient(self_device.enc_x25519_pub, message_text.toStdString()); + fallback_env.envelope.aad_b64.clear(); + fallback_env.envelope.sender_device_pubkey = state_.device.ik_ed25519_pub; + const QByteArray fallback_ciphertext_bytes = QByteArray::fromBase64( + QByteArray::fromStdString(fallback_env.envelope.ciphertext_b64)); + fallback_env.ciphertext_hash = crypto_.Sha256(fallback_ciphertext_bytes.toStdString()); + fallback_env.aad_hash = crypto_.Sha256(std::string()); + pending.push_back(fallback_env); + + RecordDiagnostic("send fallback: no recipient devices resolved; message retained for sender device"); + } + + std::vector chain_material; + chain_material.reserve(pending.size()); + for (const auto& item : pending) { + const std::string target_uid = + item.envelope.recipient_device_uid.empty() ? item.envelope.recipient_device_id + : item.envelope.recipient_device_uid; + chain_material.push_back(target_uid + ":" + item.ciphertext_hash + ":" + item.aad_hash); + } + request.sender_chain_hash = AggregateChainHash( + crypto_, + request.sender_prev_hash, + request.client_message_id, + request.sent_at_ms, + chain_material); + + std::string secret_error; + const auto sender_ik_private = secret_store_.GetSecret(SecretKey("ik_private"), &secret_error); + if (!sender_ik_private.has_value()) { + throw std::runtime_error("Device signing key unavailable"); + } + + for (auto& item : pending) { + const std::string canonical = CanonicalMessageSignature( + self_address, + state_.device.id, + item.envelope.recipient_user_address, + item.envelope.recipient_device_uid.empty() ? item.envelope.recipient_device_id + : item.envelope.recipient_device_uid, + request.client_message_id, + request.sent_at_ms, + request.sender_prev_hash, + request.sender_chain_hash, + item.ciphertext_hash, + item.aad_hash); + item.envelope.signature_b64 = crypto_.SignDetached(sender_ik_private.value(), canonical); + request.envelopes.push_back(item.envelope); + } + + const auto send_op = [this, &request]() { + return api_client_.SendMessage(state_.base_url, RequireAccessToken(), request); + }; + const auto sent = CallWithAuthRetryOnce(send_op, [this]() { RefreshAccessToken(); }); + + const std::string resolved_conversation_id = + sent.message.conversation_id.empty() ? conversation_id : sent.message.conversation_id; + selected_conversation_id_ = resolved_conversation_id; + state_.MarkMessageSeen(sent.message.id); + + LocalMessage local; + local.id = sent.message.id; + local.conversation_id = resolved_conversation_id; + local.sender_user_id = state_.user.id; + local.sender_address = self_address; + local.created_at = sent.message.created_at; + local.rendered_text = RenderMessage(sent.message, message_text.toStdString()).toStdString(); + local.plaintext = message_text.toStdString(); + state_.local_messages[selected_conversation_id_].push_back(local); + state_.last_verified_chain_hash_by_conversation_sender[sender_chain_key] = sent.message.sender_chain_hash; + if (resolved_conversation_id != conversation_id) { + const std::string canonical_chain_key = resolved_conversation_id + "|" + state_.device.id; + state_.last_verified_chain_hash_by_conversation_sender[canonical_chain_key] = sent.message.sender_chain_hash; + } + + if (group_send) { + UpsertConversationMeta( + selected_conversation_id_, + QString(), + QString(), + MessagePreview(message_text), + QString::fromStdString(sent.message.created_at)); + } else { + UpsertConversationMeta( + selected_conversation_id_, + normalized_peer.contains('@') ? normalized_peer.section('@', 0, 0) : normalized_peer, + normalized_peer.contains('@') ? normalized_peer : QString(), + MessagePreview(message_text), + QString::fromStdString(sent.message.created_at)); + } + if (resolved_conversation_id != conversation_id) { + try { + LoadConversations(); + } catch (...) { + // Keep optimistic local state even if conversation refresh fails. + } + } + + PersistState(); + RefreshConversationList(); + emit MessageSendSucceeded( + QString::fromStdString(selected_conversation_id_), + QString::fromStdString(sent.message.id)); + emit ConversationSelected( + QString::fromStdString(selected_conversation_id_), + RenderThread(selected_conversation_id_)); + } catch (const std::exception& ex) { + const QString line = QString("Send message failed: %1").arg(ex.what()); + RecordDiagnostic(line); + emit ErrorOccurred(line); + } +} + +void ApplicationController::SendFileToPeer(const QString& peer_username, const QString& file_path) { + QString normalized_peer; + bool selected_group_conversation = false; + if (!peer_username.trimmed().isEmpty()) { + QString error; + const auto normalized = NormalizePeerUsername(peer_username, &error); + if (!normalized.has_value()) { + emit ErrorOccurred(error); + return; + } + normalized_peer = normalized->trimmed().toLower(); + } else if (!selected_conversation_id_.empty()) { + const ConversationOut* selected = FindConversation(selected_conversation_id_); + if (selected != nullptr && selected->conversation_type == "group") { + selected_group_conversation = true; + } else { const auto meta_it = state_.conversation_meta.find(selected_conversation_id_); if (meta_it == state_.conversation_meta.end() || (meta_it->second.peer_address.empty() && meta_it->second.peer_username.empty())) { @@ -671,62 +2271,118 @@ void ApplicationController::SendMessageToPeer(const QString& peer_username, cons } else { normalized_peer = QString::fromStdString(meta_it->second.peer_username).trimmed().toLower(); } - } else { - emit ErrorOccurred("Please enter a valid username or username@onion"); - return; } + } else { + emit ErrorOccurred("Please enter a valid username or username@onion"); + return; + } + + const QString normalized_path = file_path.trimmed(); + if (normalized_path.isEmpty()) { + return; + } - const std::string conversation_id = ResolveConversationIdForPeer(normalized_peer); - RevealConversation(conversation_id, false); - const DeviceOut recipient_device = ResolveRecipientDevice(normalized_peer); + QFile input(normalized_path); + if (!input.open(QIODevice::ReadOnly)) { + emit ErrorOccurred(QString("Unable to open file: %1").arg(normalized_path)); + return; + } - const std::string ciphertext = crypto_.EncryptForRecipient( - recipient_device.enc_x25519_pub, - message_text.toStdString()); + const QByteArray bytes = input.readAll(); + input.close(); + if (bytes.isEmpty()) { + emit ErrorOccurred("Cannot send an empty file."); + return; + } - MessageSendRequest request; - request.conversation_id = conversation_id; - request.envelope.version = 1; - request.envelope.alg = "libsodium-sealedbox-v1"; - request.envelope.recipient_device_id = recipient_device.id; - request.envelope.ciphertext_b64 = ciphertext; - request.envelope.aad_b64.clear(); - request.envelope.client_message_id = QUuid::createUuid().toString(QUuid::WithoutBraces).toStdString(); + if (!selected_group_conversation && !normalized_peer.trimmed().isEmpty()) { + try { + ResolveRecipientDevices(normalized_peer); + } catch (...) { + // Sending will surface canonical errors; policy fallback below remains available. + } + } - const auto send_op = [this, &request]() { - return api_client_.SendMessage(state_.base_url, RequireAccessToken(), request); - }; - const auto sent = CallWithAuthRetryOnce(send_op, [this]() { RefreshAccessToken(); }); + const std::string policy_key = normalized_peer.trimmed().toLower().toStdString(); + qint64 effective_inline_limit = 0; + qint64 max_ciphertext_bytes = 0; + QString policy_source = "fallback_local"; + + const auto policy_it = peer_attachment_policy_cache_.find(policy_key); + if (policy_it != peer_attachment_policy_cache_.end() && policy_it->second.attachment_inline_max_bytes > 0) { + effective_inline_limit = policy_it->second.attachment_inline_max_bytes; + max_ciphertext_bytes = policy_it->second.max_ciphertext_bytes; + policy_source = QString::fromStdString(policy_it->second.source); + } else if (local_attachment_policy_cache_.has_value() && + local_attachment_policy_cache_->attachment_inline_max_bytes > 0) { + effective_inline_limit = local_attachment_policy_cache_->attachment_inline_max_bytes; + max_ciphertext_bytes = local_attachment_policy_cache_->max_ciphertext_bytes; + policy_source = "fallback_local"; + } else { + effective_inline_limit = kConservativeInlineAttachmentFallbackBytes; + max_ciphertext_bytes = kConservativeInlineAttachmentFallbackBytes + (4 * 1024 * 1024); + policy_source = "fallback_local"; + } + + if (bytes.size() > effective_inline_limit) { + const QString source_label = policy_source == "remote" + ? "remote federation policy" + : (policy_source == "local" ? "local federation policy" : "local fallback policy"); + emit ErrorOccurred( + QString("File is too large for encrypted inline share (%1 max, source: %2).") + .arg(FormatBytesHuman(effective_inline_limit), source_label)); + return; + } - selected_conversation_id_ = conversation_id; - state_.MarkMessageSeen(sent.message.id); + const QFileInfo info(normalized_path); + QJsonObject payload; + payload.insert("name", info.fileName()); + payload.insert("size", static_cast(bytes.size())); + payload.insert("data_b64", QString::fromLatin1(bytes.toBase64())); + const QByteArray serialized = QJsonDocument(payload).toJson(QJsonDocument::Compact); + const QString marker = + QString("%1%2").arg(kFileMessagePrefix, QString::fromLatin1(serialized.toBase64())); + const qint64 marker_bytes = marker.toUtf8().size(); + const qint64 projected_ciphertext_bytes = marker_bytes + kSealedBoxOverheadBytes; + if (max_ciphertext_bytes > 0 && projected_ciphertext_bytes > max_ciphertext_bytes) { + const QString source_label = policy_source == "remote" + ? "remote federation policy" + : (policy_source == "local" ? "local federation policy" : "local fallback policy"); + emit ErrorOccurred( + QString("Encrypted file envelope exceeds ciphertext policy (%1 max, source: %2).") + .arg(FormatBytesHuman(max_ciphertext_bytes), source_label)); + return; + } - LocalMessage local; - local.id = sent.message.id; - local.conversation_id = conversation_id; - local.sender_user_id = state_.user.id; - local.created_at = sent.message.created_at; - local.rendered_text = RenderMessage(sent.message, message_text.toStdString()).toStdString(); - local.plaintext = message_text.toStdString(); - state_.local_messages[selected_conversation_id_].push_back(local); + SendMessageToPeer(selected_group_conversation ? QString() : normalized_peer, marker); +} - UpsertConversationMeta( - selected_conversation_id_, - normalized_peer.contains('@') ? normalized_peer.section('@', 0, 0) : normalized_peer, - normalized_peer.contains('@') ? normalized_peer : QString(), - MessagePreview(message_text), - QString::fromStdString(sent.message.created_at)); +void ApplicationController::SetPresenceStatus(const QString& status) { + const QString normalized = NormalizePresenceValue(status); + user_presence_status_ = normalized; + state_.social_preferences.presence_status = normalized.toStdString(); + PersistState(); + emit UserPresenceChanged(user_presence_status_); + + if (!state_.has_user || !state_.has_device) { + return; + } + try { + PresenceSetRequest request; + request.status = normalized.toStdString(); + const auto op = [this, &request]() { + return api_client_.SetPresenceStatus(state_.base_url, RequireAccessToken(), request); + }; + const auto response = CallWithAuthRetryOnce(op, [this]() { RefreshAccessToken(); }); + user_presence_status_ = NormalizePresenceValue(QString::fromStdString(response.status)); + state_.social_preferences.presence_status = user_presence_status_.toStdString(); PersistState(); + emit UserPresenceChanged(user_presence_status_); + RefreshPresenceCache(); RefreshConversationList(); - emit MessageSendSucceeded( - QString::fromStdString(selected_conversation_id_), - QString::fromStdString(sent.message.id)); - emit ConversationSelected( - QString::fromStdString(selected_conversation_id_), - RenderThread(selected_conversation_id_)); } catch (const std::exception& ex) { - const QString line = QString("Send message failed: %1").arg(ex.what()); + const QString line = QString("Set presence failed: %1").arg(ex.what()); RecordDiagnostic(line); emit ErrorOccurred(line); } @@ -749,6 +2405,8 @@ void ApplicationController::StartVoiceCall() { VoiceCallOffer offer; offer.conversation_id = selected_conversation_id_; ws_client_.SendCallOffer(offer); + call_initiated_locally_ = true; + pending_outgoing_end_request_ = false; QString peer_user; const auto meta_it = state_.conversation_meta.find(selected_conversation_id_); @@ -804,13 +2462,30 @@ void ApplicationController::EndVoiceCall() { if (!call_id.isEmpty()) { VoiceCallEnd end_request; end_request.call_id = call_id.toStdString(); - end_request.reason = "ended"; + end_request.reason = "left"; ws_client_.SendCallEnd(end_request); - } - - if (call_id.isEmpty()) { + pending_outgoing_end_request_ = false; + } else if (IsCallState("outgoing_ringing")) { + // Ringing call id arrives asynchronously; cancel as soon as server provides it. + pending_outgoing_end_request_ = true; + QTimer::singleShot(3000, this, [this]() { + if (!pending_outgoing_end_request_) { + return; + } + if (!call_state_.call_id.trimmed().isEmpty()) { + return; + } + if (!IsCallState("ending") && !IsCallState("outgoing_ringing")) { + return; + } + pending_outgoing_end_request_ = false; + StopAudioEngine(); + TransitionCallState("idle", QString(), QString(), QString(), "left"); + }); + } else { + pending_outgoing_end_request_ = false; StopAudioEngine(); - TransitionCallState("idle", QString(), QString(), QString(), "ended"); + TransitionCallState("idle", QString(), QString(), QString(), "left"); return; } @@ -819,7 +2494,7 @@ void ApplicationController::EndVoiceCall() { call_state_.call_id, call_state_.conversation_id, call_state_.peer_user_id, - "ending"); + "leaving"); } void ApplicationController::SetCallMuted(bool muted) { @@ -1038,9 +2713,11 @@ void ApplicationController::ResetLocalState() { emit AuthStateChanged(false, QString()); emit DeviceStateChanged(false); + emit AccountDevicesChanged(std::vector{}); emit ConversationListChanged(std::vector{}); emit ConversationSelected(QString(), {}); emit ConnectionStatusChanged(connection_status_); + emit UserPresenceChanged("offline"); LoadAudioDevices(); } @@ -1136,6 +2813,15 @@ std::string ApplicationController::RequireAccessToken() { return value.value(); } +std::string ApplicationController::RequireBootstrapToken() { + std::string error; + const auto value = secret_store_.GetSecret(SecretKey("bootstrap_token"), &error); + if (!value.has_value()) { + throw std::runtime_error("Bootstrap token unavailable"); + } + return value.value(); +} + std::string ApplicationController::RequireRefreshToken() { std::string error; const auto value = secret_store_.GetSecret(SecretKey("refresh_token"), &error); @@ -1145,7 +2831,20 @@ std::string ApplicationController::RequireRefreshToken() { return value.value(); } +void ApplicationController::SaveBootstrapToken(const TokenBundle& tokens) { + if (tokens.bootstrap_token.empty()) { + throw std::runtime_error("Bootstrap token missing"); + } + std::string error; + if (!secret_store_.SetSecret(SecretKey("bootstrap_token"), tokens.bootstrap_token, &error)) { + throw std::runtime_error(error); + } +} + void ApplicationController::SaveTokenPair(const TokenBundle& tokens) { + if (tokens.access_token.empty() || tokens.refresh_token.empty()) { + throw std::runtime_error("Access/refresh token pair missing"); + } std::string error; if (!secret_store_.SetSecret(SecretKey("access_token"), tokens.access_token, &error)) { throw std::runtime_error(error); @@ -1153,6 +2852,7 @@ void ApplicationController::SaveTokenPair(const TokenBundle& tokens) { if (!secret_store_.SetSecret(SecretKey("refresh_token"), tokens.refresh_token, &error)) { throw std::runtime_error(error); } + secret_store_.DeleteSecret(SecretKey("bootstrap_token"), &error); } void ApplicationController::RefreshAccessToken() { @@ -1175,10 +2875,154 @@ void ApplicationController::StopRealtime() { connection_status_ = "Disconnected"; } +void ApplicationController::EncryptPlaintextCacheInState() { + if (!state_.has_device || state_.device.enc_x25519_pub.empty()) { + return; + } + + for (auto& [conversation_id, thread] : state_.local_messages) { + (void)conversation_id; + for (auto& message : thread) { + if (message.plaintext.empty()) { + continue; + } + try { + message.plaintext_cache_b64 = crypto_.EncryptForRecipient( + state_.device.enc_x25519_pub, + message.plaintext); + } catch (const std::exception&) { + message.plaintext_cache_b64.clear(); + } + } + } +} + +void ApplicationController::DecryptPlaintextCacheInState() { + if (!state_.has_device) { + return; + } + std::string error; + const auto private_key = secret_store_.GetSecret(SecretKey("enc_private"), &error); + if (!private_key.has_value()) { + return; + } + + for (auto& [conversation_id, thread] : state_.local_messages) { + (void)conversation_id; + for (auto& message : thread) { + if (!message.plaintext.empty() || message.plaintext_cache_b64.empty()) { + continue; + } + try { + message.plaintext = crypto_.DecryptWithPrivate( + private_key.value(), + message.plaintext_cache_b64); + } catch (const std::exception&) { + message.plaintext.clear(); + } + } + } +} + void ApplicationController::PersistState() { + EncryptPlaintextCacheInState(); state_store_.Save(state_); } +QString ApplicationController::NormalizePresenceStatus(const QString& status) const { + return NormalizePresenceValue(status); +} + +QString ApplicationController::ResolvePeerAddressForConversation(const std::string& conversation_id) const { + const auto conv_it = std::find_if( + state_.conversations.begin(), + state_.conversations.end(), + [&conversation_id](const ConversationOut& conv) { return conv.id == conversation_id; }); + if (conv_it != state_.conversations.end() && conv_it->conversation_type == "group") { + return {}; + } + + const auto meta_it = state_.conversation_meta.find(conversation_id); + if (meta_it != state_.conversation_meta.end() && !meta_it->second.peer_address.empty()) { + return QString::fromStdString(meta_it->second.peer_address).trimmed().toLower(); + } + + if (conv_it == state_.conversations.end()) { + return {}; + } + if (!conv_it->peer_address.empty()) { + return QString::fromStdString(conv_it->peer_address).trimmed().toLower(); + } + if (!conv_it->peer_username.empty()) { + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() ? ServerAuthority().toStdString() : state_.user.home_server_onion) + .trimmed() + .toLower(); + return QString("%1@%2") + .arg(QString::fromStdString(conv_it->peer_username).trimmed().toLower(), local_server); + } + return {}; +} + +QString ApplicationController::PresenceForPeerAddress(const QString& peer_address) const { + const std::string key = peer_address.trimmed().toLower().toStdString(); + if (key.empty()) { + return "offline"; + } + const auto it = peer_presence_status_by_address_.find(key); + if (it == peer_presence_status_by_address_.end()) { + return "offline"; + } + return NormalizePresenceValue(it->second); +} + +void ApplicationController::RefreshPresenceCache() { + if (!state_.has_user || !state_.has_device) { + peer_presence_status_by_address_.clear(); + return; + } + + std::vector peer_addresses; + peer_addresses.reserve(state_.conversations.size()); + std::unordered_set seen; + for (const auto& conversation : state_.conversations) { + if (state_.blocked_conversation_ids.contains(conversation.id)) { + continue; + } + if (conversation.conversation_type == "direct" && state_.dismissed_conversation_ids.contains(conversation.id)) { + continue; + } + const QString peer = ResolvePeerAddressForConversation(conversation.id); + if (peer.isEmpty()) { + continue; + } + const std::string key = peer.toStdString(); + if (seen.contains(key)) { + continue; + } + seen.insert(key); + peer_addresses.push_back(key); + } + + try { + PresenceResolveRequest request; + request.peer_addresses = peer_addresses; + const auto op = [this, &request]() { + return api_client_.ResolvePresence(state_.base_url, RequireAccessToken(), request); + }; + const auto response = CallWithAuthRetryOnce(op, [this]() { RefreshAccessToken(); }); + + peer_presence_status_by_address_.clear(); + for (const auto& peer : response.peers) { + peer_presence_status_by_address_[peer.peer_address] = + NormalizePresenceValue(QString::fromStdString(peer.status)); + } + emit UserPresenceChanged(user_presence_status_); + } catch (const std::exception& ex) { + RecordDiagnostic(QString("presence resolve failed: %1").arg(ex.what())); + } +} + void ApplicationController::RefreshConversationList() { std::vector items; items.reserve(state_.conversations.size()); @@ -1187,6 +3031,9 @@ void ApplicationController::RefreshConversationList() { if (state_.blocked_conversation_ids.contains(conv.id)) { continue; } + if (conv.conversation_type == "direct" && state_.dismissed_conversation_ids.contains(conv.id)) { + continue; + } if (pending_request_senders_.find(conv.id) != pending_request_senders_.end()) { continue; } @@ -1194,6 +3041,11 @@ void ApplicationController::RefreshConversationList() { ConversationListItemView item; item.id = QString::fromStdString(conv.id); item.title = FriendlyConversationTitle(conv.id); + item.conversation_type = QString::fromStdString(conv.conversation_type).trimmed().toLower(); + item.can_manage_members = conv.can_manage_members; + item.peer_address = ResolvePeerAddressForConversation(conv.id); + item.group_name = QString::fromStdString(conv.group_name); + item.member_count = conv.member_count; const auto meta_it = state_.conversation_meta.find(conv.id); if (meta_it != state_.conversation_meta.end() && !meta_it->second.last_preview.empty()) { @@ -1203,6 +3055,7 @@ void ApplicationController::RefreshConversationList() { } item.last_activity_at = LastActivityForConversation(conv.id); + item.status = PresenceForPeerAddress(ResolvePeerAddressForConversation(conv.id)); items.push_back(item); } @@ -1225,6 +3078,13 @@ std::vector ApplicationController::RenderThread(const std::st } QString peer_label = "Peer"; + const auto conv_it = std::find_if( + state_.conversations.begin(), + state_.conversations.end(), + [&conversation_id](const ConversationOut& conv) { return conv.id == conversation_id; }); + if (conv_it != state_.conversations.end() && conv_it->conversation_type == "group") { + peer_label = "Member"; + } const auto meta_it = state_.conversation_meta.find(conversation_id); if (meta_it != state_.conversation_meta.end()) { if (!meta_it->second.peer_username.empty()) { @@ -1299,6 +3159,70 @@ void ApplicationController::RecordDiagnostic(const QString& line) { } } +void ApplicationController::UploadCurrentDevicePrekeys() { + if (!state_.has_user || !state_.has_device) { + return; + } + + std::string error; + const auto ik_private = secret_store_.GetSecret(SecretKey("ik_private"), &error); + if (!ik_private.has_value()) { + RecordDiagnostic("prekey upload skipped: device signing key unavailable"); + return; + } + + if (state_.device.enc_x25519_pub.empty()) { + RecordDiagnostic("prekey upload skipped: device DH public key missing"); + return; + } + + const int signed_prekey_id = 1; + const QString expires_qt = QDateTime::currentDateTimeUtc().addDays(30).toString("yyyy-MM-ddTHH:mm:ss+00:00"); + const std::string expires_at = expires_qt.toStdString(); + const std::string canonical = CanonicalSignedPrekeyString( + state_.device.id, + signed_prekey_id, + state_.device.enc_x25519_pub, + expires_at); + const std::string signature = crypto_.SignDetached(ik_private.value(), canonical); + + PrekeyUploadRequest request; + request.signed_prekey.key_id = signed_prekey_id; + request.signed_prekey.pub_x25519_b64 = state_.device.enc_x25519_pub; + request.signed_prekey.sig_by_device_sign_key_b64 = signature; + request.signed_prekey.expires_at = expires_at; + + try { + const auto op = [this, &request]() { + return api_client_.UploadPrekeys(state_.base_url, RequireAccessToken(), request); + }; + const auto response = CallWithAuthRetryOnce(op, [this]() { RefreshAccessToken(); }); + RecordDiagnostic( + QString("prekeys uploaded signed_key_id=%1 accepted_opk=%2") + .arg(response.uploaded_signed_prekey_key_id) + .arg(response.accepted_one_time_prekeys)); + } catch (const std::exception& ex) { + RecordDiagnostic(QString("prekey upload failed: %1").arg(ex.what())); + } +} + +bool ApplicationController::PreferRatchetV2b1() const { + return EnvFlagEnabled("BLACKWIRE_PREFER_RATCHET_V2B1", true); +} + +bool ApplicationController::DeviceSupportsMessageMode(const DeviceOut& device, const std::string& mode) const { + if (mode.empty()) { + return false; + } + if (device.supported_message_modes.empty()) { + return mode == "sealedbox_v0_2a"; + } + return std::find( + device.supported_message_modes.begin(), + device.supported_message_modes.end(), + mode) != device.supported_message_modes.end(); +} + bool ApplicationController::ConversationExists(const std::string& conversation_id) const { return FindConversation(conversation_id) != nullptr; } @@ -1320,6 +3244,7 @@ void ApplicationController::RevealConversation(const std::string& conversation_i } state_.blocked_conversation_ids.erase(conversation_id); + state_.dismissed_conversation_ids.erase(conversation_id); const auto* conversation = FindConversation(conversation_id); if (conversation != nullptr && (!conversation->peer_username.empty() || !conversation->peer_address.empty())) { @@ -1377,6 +3302,22 @@ QString ApplicationController::MessagePreview(const QString& message) const { } QString ApplicationController::FriendlyConversationTitle(const std::string& conversation_id) const { + const auto conv_it = std::find_if( + state_.conversations.begin(), + state_.conversations.end(), + [&conversation_id](const ConversationOut& conv) { return conv.id == conversation_id; }); + if (conv_it != state_.conversations.end() && conv_it->conversation_type == "group") { + const QString group_name = QString::fromStdString(conv_it->group_name).trimmed(); + if (!group_name.isEmpty()) { + return group_name; + } + const QString group_uid = QString::fromStdString(conv_it->group_uid).trimmed(); + if (!group_uid.isEmpty()) { + return QString("Group %1").arg(group_uid.left(8)); + } + return QString("Group %1").arg(QString::fromStdString(conversation_id).left(8)); + } + const auto meta_it = state_.conversation_meta.find(conversation_id); if (meta_it != state_.conversation_meta.end()) { if (!meta_it->second.peer_address.empty()) { @@ -1411,6 +3352,167 @@ QString ApplicationController::LastActivityForConversation(const std::string& co return {}; } +QString ApplicationController::FormatCallDuration(qint64 duration_ms) const { + const qint64 total_seconds = std::max(0, duration_ms / 1000); + const qint64 hours = total_seconds / 3600; + const qint64 minutes = (total_seconds % 3600) / 60; + const qint64 seconds = total_seconds % 60; + if (hours > 0) { + return QString("%1 hours").arg(hours); + } + if (minutes > 0) { + return QString("%1 minutes").arg(minutes); + } + return QString("%1 seconds").arg(seconds); +} + +std::vector ApplicationController::BuildDirectCallParticipants(const QString& peer_user_id) const { + std::vector participants; + participants.reserve(2); + + const QString local_server = QString::fromStdString( + state_.user.home_server_onion.empty() + ? ServerAuthority().toStdString() + : state_.user.home_server_onion) + .trimmed() + .toLower(); + const QString self_address = QString("%1@%2") + .arg(QString::fromStdString(state_.user.username).trimmed().toLower(), local_server) + .trimmed() + .toLower(); + participants.push_back(CallParticipantView{ + self_address, + "You", + true, + }); + + const QString peer = peer_user_id.trimmed().toLower(); + if (peer.isEmpty()) { + return participants; + } + + QString peer_label = peer; + if (peer.contains('@')) { + peer_label = peer.section('@', 0, 0).trimmed(); + } + if (peer_label.isEmpty()) { + peer_label = "Peer"; + } + + participants.push_back(CallParticipantView{ + peer, + peer_label, + false, + }); + return participants; +} + +void ApplicationController::AppendCallHistoryEntry(const QString& reason) { + const QString previous_state = call_state_.state.trimmed().toLower(); + const QString conversation_id = call_state_.conversation_id.trimmed(); + if (conversation_id.isEmpty()) { + return; + } + + const QString normalized_reason = reason.trimmed().toLower(); + if (normalized_reason == "logged_out" || normalized_reason == "reset") { + return; + } + + const qint64 now_ms = QDateTime::currentMSecsSinceEpoch(); + qint64 started_ms = call_active_started_at_ms_ > 0 ? call_active_started_at_ms_ : call_started_at_ms_; + if (started_ms <= 0) { + started_ms = now_ms; + } + + QString peer = call_state_.peer_user_id.trimmed().isEmpty() + ? FriendlyConversationTitle(conversation_id.toStdString()) + : call_state_.peer_user_id.trimmed(); + if (peer.contains('@')) { + peer = peer.section('@', 0, 0); + } + if (peer.trimmed().isEmpty()) { + peer = "Peer"; + } + const QString duration_text = FormatCallDuration(std::max(0, now_ms - started_ms)); + + QString text; + if (previous_state == "incoming_ringing" && normalized_reason == "missed") { + text = QString("You missed a call from %1 that lasted %2.").arg(peer, duration_text); + } else if (call_initiated_locally_) { + text = QString("You started a call with %1 that lasted %2.").arg(peer, duration_text); + } else { + text = QString("%1 started a call that lasted %2.").arg(peer, duration_text); + } + + LocalMessage local; + local.id = QUuid::createUuid().toString(QUuid::WithoutBraces).toStdString(); + local.conversation_id = conversation_id.toStdString(); + local.sender_user_id.clear(); + local.sender_address.clear(); + local.created_at = QDateTime::currentDateTimeUtc().toString(Qt::ISODateWithMs).toStdString(); + local.rendered_text = text.toStdString(); + local.plaintext = text.toStdString(); + state_.local_messages[conversation_id.toStdString()].push_back(local); + + UpsertConversationMeta( + conversation_id.toStdString(), + QString(), + QString(), + text, + QString::fromStdString(local.created_at)); + + PersistState(); + RefreshConversationList(); + + const auto thread = RenderThread(conversation_id.toStdString()); + if (selected_conversation_id_ == conversation_id.toStdString()) { + emit ConversationSelected(conversation_id, thread); + } else if (!thread.empty()) { + emit IncomingMessage(conversation_id, thread.back()); + } +} + +void ApplicationController::AppendGroupRenameHistoryEntry( + const QString& conversation_id, + const QString& actor_address, + const QString& group_name, + const QString& dedupe_suffix) { + const QString conversation = conversation_id.trimmed(); + const QString next_name = group_name.trimmed(); + if (conversation.isEmpty() || next_name.isEmpty()) { + return; + } + + const QString suffix = dedupe_suffix.trimmed().isEmpty() ? next_name.toLower() : dedupe_suffix.trimmed().toLower(); + const std::string synthetic_id = + QString("sys-group-rename:%1:%2").arg(conversation, suffix).toStdString(); + if (!state_.MarkMessageSeen(synthetic_id)) { + return; + } + + const QString actor_username = UsernameFromAddress(actor_address); + const QString text = QString("%1 changed group DM name to %2") + .arg(actor_username, next_name); + + LocalMessage local; + local.id = synthetic_id; + local.conversation_id = conversation.toStdString(); + local.sender_user_id.clear(); + local.sender_address = actor_address.trimmed().toLower().toStdString(); + local.created_at = QDateTime::currentDateTimeUtc().toString(Qt::ISODateWithMs).toStdString(); + local.rendered_text = text.toStdString(); + local.plaintext = text.toStdString(); + state_.local_messages[conversation.toStdString()].push_back(local); + + UpsertConversationMeta( + conversation.toStdString(), + QString(), + QString(), + text, + QString::fromStdString(local.created_at)); +} + QString ApplicationController::PreferredInputDeviceId() const { const QString configured = QString::fromStdString(state_.audio_preferences.preferred_input_device_id).trimmed(); if (!configured.isEmpty()) { @@ -1441,15 +3543,41 @@ void ApplicationController::TransitionCallState( const QString& conversation_id, const QString& peer_user_id, const QString& reason) { + const QString previous_state = call_state_.state.trimmed().toLower(); + const QString previous_call_id = call_state_.call_id; + const QString next_state = state.trimmed().isEmpty() ? "idle" : state.trimmed().toLower(); + if (next_state == "idle" && previous_state != "idle") { + AppendCallHistoryEntry(reason); + } + + if (next_state == "incoming_ringing" || next_state == "outgoing_ringing") { + call_started_at_ms_ = QDateTime::currentMSecsSinceEpoch(); + call_active_started_at_ms_ = 0; + } else if (next_state == "active") { + if (call_started_at_ms_ <= 0) { + call_started_at_ms_ = QDateTime::currentMSecsSinceEpoch(); + } + call_active_started_at_ms_ = QDateTime::currentMSecsSinceEpoch(); + } else if (next_state == "idle") { + call_started_at_ms_ = 0; + call_active_started_at_ms_ = 0; + } + call_state_.state = state.trimmed().isEmpty() ? "idle" : state.trimmed(); call_state_.call_id = call_id; call_state_.conversation_id = conversation_id; call_state_.peer_user_id = peer_user_id; call_state_.reason = reason; + if (next_state != "active" || (!previous_call_id.isEmpty() && previous_call_id != call_id)) { + call_state_.participants.clear(); + } if (IsCallState("idle") || IsCallState("incoming_ringing") || IsCallState("outgoing_ringing")) { call_state_.muted = false; } + if (IsCallState("idle")) { + call_initiated_locally_ = false; + } EmitCallState(); } @@ -1511,27 +3639,39 @@ std::optional ApplicationController::FindConversationIdForPeer( if (key.empty()) { return std::nullopt; } + const bool key_has_server = key.find('@') != std::string::npos; - const auto key_username = key.find('@') != std::string::npos ? key.substr(0, key.find('@')) : key; - - for (const auto& [conversation_id, meta] : state_.conversation_meta) { - const bool address_match = !meta.peer_address.empty() && meta.peer_address == key; - const bool username_match = !meta.peer_username.empty() && meta.peer_username == key; - const bool loose_username_match = - !meta.peer_username.empty() && meta.peer_username == key_username; - if (!address_match && !username_match && !loose_username_match) { + std::optional exact_username_match; + std::optional local_prefix_match; + for (const auto& conversation : state_.conversations) { + const auto meta_it = state_.conversation_meta.find(conversation.id); + if (meta_it == state_.conversation_meta.end()) { continue; } - - const auto conv_it = std::find_if( - state_.conversations.begin(), - state_.conversations.end(), - [&conversation_id](const ConversationOut& conv) { return conv.id == conversation_id; }); - if (conv_it != state_.conversations.end()) { - return conversation_id; + const auto& meta = meta_it->second; + if (!meta.peer_address.empty() && meta.peer_address == key) { + return conversation.id; + } + if (!key_has_server && !meta.peer_username.empty() && meta.peer_username == key) { + if (!exact_username_match.has_value()) { + exact_username_match = conversation.id; + } + continue; + } + if (!key_has_server && !meta.peer_address.empty() && meta.peer_address.rfind(key + "@", 0) == 0) { + if (!local_prefix_match.has_value()) { + local_prefix_match = conversation.id; + } } } + if (exact_username_match.has_value()) { + return exact_username_match; + } + if (local_prefix_match.has_value()) { + return local_prefix_match; + } + return std::nullopt; } @@ -1591,7 +3731,7 @@ std::string ApplicationController::ResolveConversationIdForPeer(const QString& n return conversation.id; } -DeviceOut ApplicationController::ResolveRecipientDevice(const QString& normalized_peer) { +std::vector ApplicationController::ResolveRecipientDevices(const QString& normalized_peer) { const std::string key = normalized_peer.trimmed().toLower().toStdString(); if (key.empty()) { throw std::runtime_error("Peer address is required"); @@ -1606,8 +3746,37 @@ DeviceOut ApplicationController::ResolveRecipientDevice(const QString& normalize return api_client_.GetUserDevice(state_.base_url, RequireAccessToken(), key); }; const auto user_device = CallWithAuthRetryOnce(get_device, [this]() { RefreshAccessToken(); }); - peer_device_cache_[key] = user_device.device; - return user_device.device; + AttachmentPolicyCacheEntry policy_entry; + policy_entry.attachment_inline_max_bytes = std::max(0, user_device.attachment_inline_max_bytes); + policy_entry.max_ciphertext_bytes = std::max(0, user_device.max_ciphertext_bytes); + policy_entry.source = user_device.attachment_policy_source.empty() ? "local" : user_device.attachment_policy_source; + peer_attachment_policy_cache_[key] = policy_entry; + if (policy_entry.source == "local" && policy_entry.attachment_inline_max_bytes > 0 && + policy_entry.max_ciphertext_bytes > 0) { + local_attachment_policy_cache_ = policy_entry; + } + + if (user_device.devices.empty() && (!user_device.device.id.empty() || !user_device.device.device_uid.empty())) { + peer_device_cache_[key] = {user_device.device}; + } else { + peer_device_cache_[key] = user_device.devices; + } + return peer_device_cache_[key]; +} + +std::vector ApplicationController::ResolveOwnActiveDevices() { + const auto operation = [this]() { + return api_client_.ListDevices(state_.base_url, RequireAccessToken()); + }; + auto devices = CallWithAuthRetryOnce>(operation, [this]() { RefreshAccessToken(); }); + std::vector active; + for (const auto& device : devices) { + if (device.status == "revoked") { + continue; + } + active.push_back(device); + } + return active; } void ApplicationController::UpsertConversationMeta( @@ -1635,11 +3804,19 @@ void ApplicationController::ClearInMemoryState() { state_ = ClientState{}; state_.base_url = "http://localhost:8000"; peer_device_cache_.clear(); + peer_attachment_policy_cache_.clear(); + local_attachment_policy_cache_.reset(); + peer_presence_status_by_address_.clear(); pending_request_messages_.clear(); pending_request_senders_.clear(); selected_conversation_id_.clear(); connection_status_ = "Disconnected"; + user_presence_status_ = "active"; call_state_ = CallStateView{}; + call_initiated_locally_ = false; + call_started_at_ms_ = 0; + call_active_started_at_ms_ = 0; + pending_outgoing_end_request_ = false; audio_sequence_ = 0; } diff --git a/client-cpp-gui/src/crypto/sodium_crypto_service.cpp b/client-cpp-gui/src/crypto/sodium_crypto_service.cpp index 1d8191c..eb44c87 100644 --- a/client-cpp-gui/src/crypto/sodium_crypto_service.cpp +++ b/client-cpp-gui/src/crypto/sodium_crypto_service.cpp @@ -1,5 +1,7 @@ #include "blackwire/crypto/sodium_crypto_service.hpp" +#include +#include #include #include @@ -85,6 +87,60 @@ std::string SodiumCryptoService::DecryptWithPrivate( return std::string(reinterpret_cast(plaintext.data()), plaintext.size()); } +std::string SodiumCryptoService::SignDetached( + const std::string& ed25519_private_key_b64, + const std::string& message) { + const auto private_key = DecodeBase64(ed25519_private_key_b64); + if (private_key.size() != crypto_sign_SECRETKEYBYTES) { + throw std::runtime_error("Ed25519 private key has invalid length"); + } + + unsigned char signature[crypto_sign_BYTES]; + if (crypto_sign_detached( + signature, + nullptr, + reinterpret_cast(message.data()), + message.size(), + private_key.data()) != 0) { + throw std::runtime_error("Unable to sign payload"); + } + + return EncodeBase64(signature, crypto_sign_BYTES); +} + +bool SodiumCryptoService::VerifyDetached( + const std::string& ed25519_public_key_b64, + const std::string& message, + const std::string& signature_b64) { + const auto public_key = DecodeBase64(ed25519_public_key_b64); + const auto signature = DecodeBase64(signature_b64); + if (public_key.size() != crypto_sign_PUBLICKEYBYTES || signature.size() != crypto_sign_BYTES) { + return false; + } + + const int verified = crypto_sign_verify_detached( + signature.data(), + reinterpret_cast(message.data()), + message.size(), + public_key.data()); + return verified == 0; +} + +std::string SodiumCryptoService::Sha256(const std::string& data) { + unsigned char digest[crypto_hash_sha256_BYTES]; + crypto_hash_sha256( + digest, + reinterpret_cast(data.data()), + data.size()); + + std::ostringstream out; + out << std::hex << std::setfill('0'); + for (unsigned char byte : digest) { + out << std::setw(2) << static_cast(byte); + } + return out.str(); +} + std::string SodiumCryptoService::EncodeBase64(const unsigned char* bytes, std::size_t length) { const std::size_t out_size = sodium_base64_ENCODED_LEN(length, sodium_base64_VARIANT_ORIGINAL); std::string out(out_size, '\0'); diff --git a/client-cpp-gui/src/smoke/smoke_runner.cpp b/client-cpp-gui/src/smoke/smoke_runner.cpp index fa85bcf..58b7378 100644 --- a/client-cpp-gui/src/smoke/smoke_runner.cpp +++ b/client-cpp-gui/src/smoke/smoke_runner.cpp @@ -1,6 +1,7 @@ #include "blackwire/smoke/smoke_runner.hpp" #include +#include #include #include #include @@ -14,6 +15,37 @@ namespace blackwire { +namespace { + +std::string CanonicalMessageSignature( + const std::string& sender_address, + const std::string& sender_device_uid, + const std::string& recipient_address, + const std::string& recipient_device_uid, + const std::string& client_message_id, + long long sent_at_ms, + const std::string& sender_prev_hash, + const std::string& sender_chain_hash, + const std::string& ciphertext_hash, + const std::string& aad_hash) { + return sender_address + "\n" + sender_device_uid + "\n" + recipient_address + "\n" + recipient_device_uid + "\n" + + client_message_id + "\n" + std::to_string(sent_at_ms) + "\n" + sender_prev_hash + "\n" + sender_chain_hash + + "\n" + ciphertext_hash + "\n" + aad_hash; +} + +std::string AggregateChainHash( + ICryptoService& crypto, + const std::string& sender_prev_hash, + const std::string& client_message_id, + long long sent_at_ms, + const std::string& hash_material) { + const std::string aggregate_hash = crypto.Sha256(hash_material); + return crypto.Sha256( + sender_prev_hash + "\n" + client_message_id + "\n" + std::to_string(sent_at_ms) + "\n" + aggregate_hash); +} + +} // namespace + int SmokeRunner::Run(const QString& base_url) { try { QtApiClient api; @@ -35,34 +67,74 @@ int SmokeRunner::Run(const QString& base_url) { alice_device_req.label = "smoke-alice"; alice_device_req.ik_ed25519_pub = alice_keys.ik_ed25519_public_b64; alice_device_req.enc_x25519_pub = alice_keys.enc_x25519_public_b64; - api.RegisterDevice(base_url.toStdString(), alice_auth.tokens.access_token, alice_device_req); + const auto alice_device_auth = api.RegisterDevice( + base_url.toStdString(), + alice_auth.tokens.bootstrap_token, + alice_device_req); DeviceRegisterRequest bob_device_req; bob_device_req.label = "smoke-bob"; bob_device_req.ik_ed25519_pub = bob_keys.ik_ed25519_public_b64; bob_device_req.enc_x25519_pub = bob_keys.enc_x25519_public_b64; - const auto bob_device = api.RegisterDevice( + const auto bob_device_auth = api.RegisterDevice( base_url.toStdString(), - bob_auth.tokens.access_token, + bob_auth.tokens.bootstrap_token, bob_device_req); + const auto alice_access = alice_device_auth.tokens.access_token; + const auto bob_access = bob_device_auth.tokens.access_token; + const auto alice_device_uid = alice_device_auth.tokens.device_uid; + const auto conversation = api.CreateDm( base_url.toStdString(), - alice_auth.tokens.access_token, + alice_access, "", bob_name); + const auto bob_lookup = api.GetUserDevice( + base_url.toStdString(), + alice_access, + bob_name + "@local.invalid"); + if (bob_lookup.devices.empty()) { + throw std::runtime_error("Smoke test: no bob devices resolved"); + } + const auto bob_device = bob_lookup.devices.front(); + const std::string plaintext = "blackwire-smoke-message"; const std::string ciphertext = crypto.EncryptForRecipient(bob_device.enc_x25519_pub, plaintext); const std::string client_message_id = QUuid::createUuid().toString(QUuid::WithoutBraces).toStdString(); + const long long sent_at_ms = QDateTime::currentMSecsSinceEpoch(); + const std::string sender_prev_hash; + const std::string ciphertext_hash = crypto.Sha256(QByteArray::fromBase64(QByteArray::fromStdString(ciphertext)).toStdString()); + const std::string aad_hash = crypto.Sha256(std::string()); + const std::string chain_material = bob_device.device_uid + ":" + ciphertext_hash + ":" + aad_hash; + const std::string sender_chain_hash = AggregateChainHash( + crypto, + sender_prev_hash, + client_message_id, + sent_at_ms, + chain_material); + const std::string sender_address = alice_name + "@local.invalid"; + const std::string canonical = CanonicalMessageSignature( + sender_address, + alice_device_uid, + bob_name + "@local.invalid", + bob_device.device_uid, + client_message_id, + sent_at_ms, + sender_prev_hash, + sender_chain_hash, + ciphertext_hash, + aad_hash); + const std::string signature = crypto.SignDetached(alice_keys.ik_ed25519_private_b64, canonical); bool received = false; bool decrypted_ok = false; ws.SetHandlers( [&](const WsEventMessageNew& event) { - ws.SendAck(event.message.id); + ws.SendAck(event.copy_id.empty() ? event.message.id : event.copy_id); try { const auto decrypted = crypto.DecryptWithPrivate( bob_keys.enc_x25519_private_b64, @@ -75,6 +147,7 @@ int SmokeRunner::Run(const QString& base_url) { } }, [&](const WsEventCallIncoming&) {}, + [&](const WsEventCallGroupState&) {}, [&](const WsEventCallRinging&) {}, [&](const WsEventCallAccepted&) {}, [&](const WsEventCallRejected&) {}, @@ -82,22 +155,33 @@ int SmokeRunner::Run(const QString& base_url) { [&](const WsEventCallEnded&) {}, [&](const WsEventCallAudio&) {}, [&](const WsEventCallError&) {}, + [&](const WsEventCallWebRtcOffer&) {}, + [&](const WsEventCallWebRtcAnswer&) {}, + [&](const WsEventCallWebRtcIce&) {}, + [&](const WsEventGroupRenamed&) {}, [&](const std::string& error) { std::cerr << "WS error: " << error << '\n'; }, [&](bool) {}); - ws.Connect(base_url.toStdString(), bob_auth.tokens.access_token); + ws.Connect(base_url.toStdString(), bob_access); MessageSendRequest request; request.conversation_id = conversation.id; - request.envelope.version = 1; - request.envelope.alg = "libsodium-sealedbox-v1"; - request.envelope.recipient_device_id = bob_device.id; - request.envelope.ciphertext_b64 = ciphertext; - request.envelope.client_message_id = client_message_id; - - api.SendMessage(base_url.toStdString(), alice_auth.tokens.access_token, request); + request.client_message_id = client_message_id; + request.sent_at_ms = sent_at_ms; + request.sender_prev_hash = sender_prev_hash; + request.sender_chain_hash = sender_chain_hash; + CipherEnvelope env; + env.recipient_user_address = bob_name + "@local.invalid"; + env.recipient_device_uid = bob_device.device_uid; + env.ciphertext_b64 = ciphertext; + env.aad_b64.clear(); + env.signature_b64 = signature; + env.sender_device_pubkey = alice_keys.ik_ed25519_public_b64; + request.envelopes = {env}; + + api.SendMessage(base_url.toStdString(), alice_access, request); QEventLoop loop; QTimer timer; diff --git a/client-cpp-gui/src/ui/chat_widget.cpp b/client-cpp-gui/src/ui/chat_widget.cpp index fc7e6e0..b11046c 100644 --- a/client-cpp-gui/src/ui/chat_widget.cpp +++ b/client-cpp-gui/src/ui/chat_widget.cpp @@ -1,9 +1,18 @@ #include "blackwire/ui/chat_widget.hpp" +#include +#include + #include +#include +#include #include +#include +#include #include #include +#include +#include #include #include #include @@ -19,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +40,94 @@ namespace blackwire { namespace { +const char* kFileMessagePrefix = "bwfile://v1:"; +constexpr int kRoleConversationId = Qt::UserRole; +constexpr int kRoleTitle = Qt::UserRole + 1; +constexpr int kRoleSubtitle = Qt::UserRole + 2; +constexpr int kRoleStatus = Qt::UserRole + 3; +constexpr int kRoleConversationType = Qt::UserRole + 4; +constexpr int kRoleCanManageMembers = Qt::UserRole + 5; +constexpr int kRolePeerAddress = Qt::UserRole + 6; +constexpr int kRoleGroupName = Qt::UserRole + 7; +constexpr int kRoleMemberCount = Qt::UserRole + 8; + +QString NormalizePresenceStatus(const QString& value) { + const QString normalized = value.trimmed().toLower(); + if (normalized == "active" || normalized == "inactive" || normalized == "offline" || normalized == "dnd") { + return normalized; + } + return "offline"; +} + +QString PresenceColorForStatus(const QString& value) { + const QString normalized = NormalizePresenceStatus(value); + if (normalized == "active") { + return "#2d7d46"; // green + } + if (normalized == "inactive") { + return "#d4a63c"; // yellow + } + if (normalized == "dnd") { + return "#9e2c31"; // red + } + return "#6a6f78"; // offline/default gray +} + +QString HumanFileSize(qint64 size_bytes) { + if (size_bytes >= 1024 * 1024) { + return QString("%1 MB").arg(QString::number(static_cast(size_bytes) / (1024.0 * 1024.0), 'f', 1)); + } + if (size_bytes >= 1024) { + return QString("%1 KB").arg(QString::number(static_cast(size_bytes) / 1024.0, 'f', 1)); + } + return QString("%1 bytes").arg(size_bytes); +} + +bool DecodeFileMessageBody( + const QString& body, + QString* file_name, + QByteArray* file_bytes, + qint64* file_size_bytes) { + if (!body.startsWith(kFileMessagePrefix, Qt::CaseInsensitive)) { + return false; + } + const QString encoded = body.mid(static_cast(strlen(kFileMessagePrefix))).trimmed(); + if (encoded.isEmpty()) { + return false; + } + const QByteArray decoded_json = QByteArray::fromBase64(encoded.toUtf8()); + if (decoded_json.isEmpty()) { + return false; + } + + QJsonParseError parse_error{}; + const QJsonDocument doc = QJsonDocument::fromJson(decoded_json, &parse_error); + if (parse_error.error != QJsonParseError::NoError || !doc.isObject()) { + return false; + } + const QJsonObject obj = doc.object(); + const QString candidate_name = obj.value("name").toString().trimmed(); + const QString data_b64 = obj.value("data_b64").toString(); + if (candidate_name.isEmpty() || data_b64.isEmpty()) { + return false; + } + const QByteArray bytes = QByteArray::fromBase64(data_b64.toUtf8()); + if (bytes.isEmpty()) { + return false; + } + + if (file_name != nullptr) { + *file_name = candidate_name; + } + if (file_bytes != nullptr) { + *file_bytes = bytes; + } + if (file_size_bytes != nullptr) { + *file_size_bytes = static_cast(obj.value("size").toDouble(static_cast(bytes.size()))); + } + return true; +} + QString FormatCallStatusText(const CallStateView& state) { const QString normalized = state.state.trimmed().toLower(); if (normalized == "active") { @@ -129,10 +227,21 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { sidebar_layout->setContentsMargins(14, 14, 14, 14); sidebar_layout->setSpacing(10); - auto* sidebar_title = new QLabel("Direct Messages", sidebar); + auto* sidebar_title = new QLabel("Messages", sidebar); sidebar_title->setObjectName("dmSidebarTitle"); sidebar_layout->addWidget(sidebar_title); + auto* contacts_toggle_card = new QWidget(sidebar); + contacts_toggle_card->setObjectName("contactsToggleCard"); + auto* contacts_toggle_layout = new QHBoxLayout(contacts_toggle_card); + contacts_toggle_layout->setContentsMargins(10, 8, 10, 8); + contacts_toggle_layout->setSpacing(0); + contacts_button_ = new QPushButton("Contacts", contacts_toggle_card); + contacts_button_->setObjectName("secondaryButton"); + contacts_button_->setCheckable(true); + contacts_toggle_layout->addWidget(contacts_button_); + sidebar_layout->addWidget(contacts_toggle_card); + auto* peer_row = new QHBoxLayout(); peer_row->setSpacing(8); peer_input_ = new QLineEdit(sidebar); @@ -182,15 +291,37 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { auto* header_row = new QHBoxLayout(); header_row->setSpacing(8); - thread_title_label_ = new QLabel("Select a conversation", content); + thread_title_label_ = new QLineEdit("Select a conversation", content); thread_title_label_->setObjectName("threadTitle"); + thread_title_label_->setReadOnly(true); + thread_title_label_->setFrame(false); + thread_title_label_->setFocusPolicy(Qt::NoFocus); + thread_title_label_->setCursor(Qt::ArrowCursor); + call_button_ = new QPushButton("Call", content); + call_button_->setObjectName("primaryButton"); + group_button_ = new QPushButton("Group", content); + group_button_->setObjectName("secondaryButton"); + invite_button_ = new QPushButton("Invite", content); + invite_button_->setObjectName("secondaryButton"); + presence_indicator_ = new QLabel(content); + presence_indicator_->setObjectName("presenceDot"); + presence_indicator_->setFixedSize(10, 10); + presence_indicator_->setProperty("state", "offline"); + presence_combo_ = new QComboBox(content); + presence_combo_->setObjectName("statusCombo"); + presence_combo_->addItem("Active", "active"); + presence_combo_->addItem("Inactive", "inactive"); + presence_combo_->addItem("Offline", "offline"); + presence_combo_->addItem("DND", "dnd"); status_label_ = new QLabel("Disconnected", content); status_label_->setObjectName("connectionPill"); status_label_->setProperty("state", "disconnected"); - call_button_ = new QPushButton("Call", content); - call_button_->setObjectName("primaryButton"); header_row->addWidget(thread_title_label_, 1); header_row->addWidget(call_button_); + header_row->addWidget(group_button_); + header_row->addWidget(invite_button_); + header_row->addWidget(presence_indicator_); + header_row->addWidget(presence_combo_); header_row->addWidget(status_label_); content_layout->addLayout(header_row); @@ -200,8 +331,14 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { banner_label_->setProperty("severity", "info"); banner_label_->setVisible(false); content_layout->addWidget(banner_label_); + content_stack_ = new QStackedWidget(content); + + chat_panel_ = new QWidget(content_stack_); + auto* chat_layout = new QVBoxLayout(chat_panel_); + chat_layout->setContentsMargins(0, 0, 0, 0); + chat_layout->setSpacing(10); - call_panel_ = new QWidget(content); + call_panel_ = new QWidget(chat_panel_); call_panel_->setObjectName("callPanel"); auto* call_panel_layout = new QHBoxLayout(call_panel_); call_panel_layout->setContentsMargins(12, 12, 12, 12); @@ -222,10 +359,17 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { call_status_label_ = new QLabel("Idle", call_panel_); call_status_label_->setObjectName("callStatusPill"); call_status_label_->setProperty("state", "idle"); + call_participants_panel_ = new QWidget(call_panel_); + call_participants_panel_->setObjectName("callParticipantsPanel"); + call_participants_layout_ = new QHBoxLayout(call_participants_panel_); + call_participants_layout_->setContentsMargins(0, 0, 0, 0); + call_participants_layout_->setSpacing(6); call_center_col->addWidget(call_panel_title_); call_center_col->addWidget(call_panel_subtitle_); call_center_col->addWidget(call_status_label_, 0, Qt::AlignLeft); + call_center_col->addWidget(call_participants_panel_, 0, Qt::AlignLeft); call_panel_layout->addLayout(call_center_col, 1); + call_participants_panel_->setVisible(false); auto* call_controls_row = new QHBoxLayout(); call_controls_row->setSpacing(8); @@ -244,9 +388,9 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { call_panel_layout->addLayout(call_controls_row); call_panel_->setVisible(false); - content_layout->addWidget(call_panel_); + chat_layout->addWidget(call_panel_); - timeline_stack_ = new QStackedWidget(content); + timeline_stack_ = new QStackedWidget(chat_panel_); empty_state_label_ = new QLabel("Select a conversation to start chatting.", timeline_stack_); empty_state_label_->setObjectName("chatEmptyState"); empty_state_label_->setAlignment(Qt::AlignCenter); @@ -261,24 +405,48 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { timeline_stack_->addWidget(empty_state_label_); timeline_stack_->addWidget(messages_list_); timeline_stack_->setCurrentWidget(empty_state_label_); - content_layout->addWidget(timeline_stack_, 1); + chat_layout->addWidget(timeline_stack_, 1); auto* compose_row = new QHBoxLayout(); compose_row->setSpacing(10); - compose_input_ = new QPlainTextEdit(content); + compose_input_ = new QPlainTextEdit(chat_panel_); compose_input_->setObjectName("composeInput"); compose_input_->setPlaceholderText("Message #dm"); compose_input_->setMaximumBlockCount(200); compose_input_->setFixedHeight(96); compose_input_->installEventFilter(this); - send_button_ = new QPushButton("Send", content); + attach_button_ = new QPushButton("Attach", chat_panel_); + attach_button_->setObjectName("secondaryButton"); + send_button_ = new QPushButton("Send", chat_panel_); send_button_->setObjectName("primaryButton"); send_button_->setEnabled(false); compose_row->addWidget(compose_input_, 1); + compose_row->addWidget(attach_button_); compose_row->addWidget(send_button_); - content_layout->addLayout(compose_row); + chat_layout->addLayout(compose_row); + + contacts_panel_ = new QWidget(content_stack_); + auto* contacts_layout = new QVBoxLayout(contacts_panel_); + contacts_layout->setContentsMargins(0, 0, 0, 0); + contacts_layout->setSpacing(8); + + auto* contacts_title = new QLabel("Contacts", contacts_panel_); + contacts_title->setObjectName("contactsSidebarTitle"); + contacts_layout->addWidget(contacts_title); + + contacts_panel_list_ = new QListWidget(contacts_panel_); + contacts_panel_list_->setObjectName("contactsList"); + contacts_panel_list_->setSelectionMode(QAbstractItemView::SingleSelection); + contacts_panel_list_->setVerticalScrollMode(QAbstractItemView::ScrollPerPixel); + contacts_panel_list_->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff); + contacts_panel_list_->setSpacing(2); + contacts_layout->addWidget(contacts_panel_list_, 1); + + content_stack_->addWidget(chat_panel_); + content_stack_->addWidget(contacts_panel_); + content_layout->addWidget(content_stack_, 1); split->setStretchFactor(0, 0); split->setStretchFactor(1, 1); @@ -297,9 +465,45 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { }); connect(new_chat_button_, &QPushButton::clicked, this, &ChatWidget::NewConversationRequested); + connect(new_chat_button_, &QPushButton::clicked, this, [this]() { + SetContactsMode(false); + }); connect(send_button_, &QPushButton::clicked, this, &ChatWidget::SendMessageRequested); + connect(attach_button_, &QPushButton::clicked, this, [this]() { + const QString file_path = QFileDialog::getOpenFileName(this, "Select file to share"); + if (!file_path.trimmed().isEmpty()) { + emit SendFileRequested(file_path.trimmed()); + } + }); connect(settings_button_, &QPushButton::clicked, this, &ChatWidget::SettingsRequested); connect(call_button_, &QPushButton::clicked, this, &ChatWidget::StartVoiceCallRequested); + connect(group_button_, &QPushButton::clicked, this, &ChatWidget::CreateGroupFromDmRequested); + connect(invite_button_, &QPushButton::clicked, this, &ChatWidget::GroupInviteDialogRequested); + connect(thread_title_label_, &QLineEdit::returnPressed, this, [this]() { + if (contacts_mode_active_) { + return; + } + const auto* selected = conversations_list_->currentItem(); + if (selected == nullptr) { + return; + } + const QString conversation_type = selected->data(kRoleConversationType).toString().trimmed().toLower(); + const bool can_manage_members = selected->data(kRoleCanManageMembers).toBool(); + if (conversation_type != "group" || !can_manage_members) { + return; + } + + const QString current_name = selected->data(kRoleGroupName).toString().trimmed(); + const QString typed_name = thread_title_label_->text().trimmed(); + if (typed_name.isEmpty()) { + thread_title_label_->setText(current_name.isEmpty() ? selected->data(kRoleTitle).toString() : current_name); + return; + } + if (!current_name.isEmpty() && typed_name == current_name) { + return; + } + emit GroupRenameRequested(typed_name); + }); connect(accept_call_button_, &QPushButton::clicked, this, &ChatWidget::AcceptVoiceCallRequested); connect(decline_call_button_, &QPushButton::clicked, this, &ChatWidget::RejectVoiceCallRequested); connect(end_call_button_, &QPushButton::clicked, this, &ChatWidget::EndVoiceCallRequested); @@ -321,12 +525,26 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { SetSendEnabled(!ComposeText().trimmed().isEmpty()); }); + connect(presence_combo_, &QComboBox::currentIndexChanged, this, [this](int index) { + if (index < 0) { + return; + } + const QString status = NormalizePresenceStatus(presence_combo_->itemData(index).toString()); + UpdatePresenceIndicator(status); + emit UserStatusChanged(status); + }); + + connect(contacts_button_, &QPushButton::clicked, this, [this]() { + SetContactsMode(!contacts_mode_active_); + }); + connect(conversations_list_, &QListWidget::itemSelectionChanged, this, [this]() { + const QString id = SelectedConversation(); + SyncContactSelection(id); + SetContactsMode(false); RefreshConversationSelectionStyles(); UpdateThreadHeader(); UpdateCallControls(); - - const QString id = SelectedConversation(); if (id.isEmpty()) { SetTimelineHasMessages(false); return; @@ -337,6 +555,24 @@ ChatWidget::ChatWidget(QWidget* parent) : QWidget(parent) { emit ConversationSelected(id); }); + connect(contacts_panel_list_, &QListWidget::itemSelectionChanged, this, [this]() { + const auto* selected = contacts_panel_list_->currentItem(); + if (selected == nullptr) { + return; + } + const QString id = selected->data(kRoleConversationId).toString(); + if (id.isEmpty()) { + return; + } + SetSelectedConversation(id); + SetContactsMode(false); + messages_list_->clear(); + SetTimelineHasMessages(false); + emit ConversationSelected(id); + }); + + SetContactsMode(true); + UpdatePresenceIndicator("offline"); UpdateCallControls(); } @@ -353,13 +589,15 @@ QString ChatWidget::SelectedConversation() const { if (item == nullptr) { return {}; } - return item->data(Qt::UserRole).toString(); + return item->data(kRoleConversationId).toString(); } void ChatWidget::SetSelectedConversation(const QString& conversation_id) { QSignalBlocker signal_blocker(conversations_list_); + QSignalBlocker contact_signal_blocker(contacts_panel_list_); if (conversation_id.isEmpty()) { conversations_list_->setCurrentItem(nullptr); + contacts_panel_list_->setCurrentItem(nullptr); RefreshConversationSelectionStyles(); UpdateThreadHeader(); UpdateCallControls(); @@ -368,19 +606,30 @@ void ChatWidget::SetSelectedConversation(const QString& conversation_id) { for (int index = 0; index < conversations_list_->count(); ++index) { auto* item = conversations_list_->item(index); - if (item->data(Qt::UserRole).toString() != conversation_id) { + if (item->data(kRoleConversationId).toString() != conversation_id) { continue; } conversations_list_->setCurrentItem(item); break; } + for (int index = 0; index < contacts_panel_list_->count(); ++index) { + auto* item = contacts_panel_list_->item(index); + if (item->data(kRoleConversationId).toString() != conversation_id) { + continue; + } + contacts_panel_list_->setCurrentItem(item); + break; + } RefreshConversationSelectionStyles(); UpdateThreadHeader(); UpdateCallControls(); } -QWidget* ChatWidget::CreateConversationItemWidget(const ConversationListItemView& item, bool selected) const { +QWidget* ChatWidget::CreateConversationItemWidget( + const ConversationListItemView& item, + bool selected, + bool removable) { auto* row = new QWidget(); row->setObjectName("conversationRow"); row->setProperty("selected", selected); @@ -389,6 +638,20 @@ QWidget* ChatWidget::CreateConversationItemWidget(const ConversationListItemView layout->setContentsMargins(10, 8, 10, 8); layout->setSpacing(8); + auto* status_dot = new QLabel(row); + status_dot->setObjectName("presenceDot"); + const QString normalized_status = NormalizePresenceStatus(item.status); + status_dot->setProperty("state", normalized_status); + status_dot->setFixedSize(10, 10); + status_dot->setStyleSheet( + QString( + "background:%1;" + "border-radius:5px;" + "min-width:10px;max-width:10px;" + "min-height:10px;max-height:10px;") + .arg(PresenceColorForStatus(normalized_status))); + layout->addWidget(status_dot, 0, Qt::AlignTop); + auto* text_col = new QVBoxLayout(); text_col->setSpacing(2); @@ -409,6 +672,24 @@ QWidget* ChatWidget::CreateConversationItemWidget(const ConversationListItemView layout->addLayout(text_col, 1); layout->addWidget(time); + if (removable) { + auto* remove_button = new QPushButton("x", row); + remove_button->setObjectName("conversationRemoveButton"); + remove_button->setFixedSize(18, 18); + remove_button->setFocusPolicy(Qt::NoFocus); + remove_button->setToolTip( + item.conversation_type.trimmed().toLower() == "group" + ? "Leave group" + : "Remove DM"); + connect(remove_button, &QPushButton::clicked, row, [this, item]() { + emit ConversationRemoveRequested( + item.id, + item.conversation_type.trimmed().toLower(), + item.can_manage_members, + item.title.trimmed()); + }); + layout->addWidget(remove_button, 0, Qt::AlignTop); + } return row; } @@ -436,11 +717,40 @@ QWidget* ChatWidget::CreateThreadMessageWidget(const ThreadMessageView& message) bubble_layout->addWidget(meta); } - auto* body = new QLabel(message.body, bubble); - body->setObjectName("messageBody"); - body->setWordWrap(true); - body->setTextInteractionFlags(Qt::TextSelectableByMouse); - bubble_layout->addWidget(body); + QString file_name; + QByteArray file_bytes; + qint64 file_size = 0; + const bool is_file = DecodeFileMessageBody(message.body, &file_name, &file_bytes, &file_size); + if (is_file) { + auto* file_label = new QLabel( + QString("Encrypted file: %1 (%2)").arg(file_name, HumanFileSize(file_size)), + bubble); + file_label->setObjectName("messageBody"); + file_label->setWordWrap(true); + bubble_layout->addWidget(file_label); + + auto* save_button = new QPushButton("Save file", bubble); + save_button->setObjectName("secondaryButton"); + QObject::connect(save_button, &QPushButton::clicked, bubble, [file_name, file_bytes, bubble]() { + const QString target_path = QFileDialog::getSaveFileName(bubble, "Save file", file_name); + if (target_path.trimmed().isEmpty()) { + return; + } + QFile output(target_path); + if (!output.open(QIODevice::WriteOnly)) { + return; + } + output.write(file_bytes); + output.close(); + }); + bubble_layout->addWidget(save_button, 0, Qt::AlignLeft); + } else { + auto* body = new QLabel(message.body, bubble); + body->setObjectName("messageBody"); + body->setWordWrap(true); + body->setTextInteractionFlags(Qt::TextSelectableByMouse); + bubble_layout->addWidget(body); + } if (message.outgoing) { row_layout->addStretch(1); @@ -456,20 +766,44 @@ QWidget* ChatWidget::CreateThreadMessageWidget(const ThreadMessageView& message) void ChatWidget::SetConversationList(const std::vector& conversations) { const QString selected_id = SelectedConversation(); QSignalBlocker signal_blocker(conversations_list_); + QSignalBlocker contact_signal_blocker(contacts_panel_list_); conversations_list_->clear(); + contacts_panel_list_->clear(); for (const auto& item : conversations) { - auto* list_item = new QListWidgetItem(conversations_list_); - list_item->setData(Qt::UserRole, item.id); - list_item->setData(Qt::UserRole + 1, item.title); - list_item->setData(Qt::UserRole + 2, item.subtitle); - list_item->setSizeHint(QSize(200, 64)); - - auto* row = CreateConversationItemWidget(item, item.id == selected_id); - conversations_list_->setItemWidget(list_item, row); + auto* dm_item = new QListWidgetItem(conversations_list_); + dm_item->setData(kRoleConversationId, item.id); + dm_item->setData(kRoleTitle, item.title); + dm_item->setData(kRoleSubtitle, item.subtitle); + dm_item->setData(kRoleStatus, item.status); + dm_item->setData(kRoleConversationType, item.conversation_type); + dm_item->setData(kRoleCanManageMembers, item.can_manage_members); + dm_item->setData(kRolePeerAddress, item.peer_address); + dm_item->setData(kRoleGroupName, item.group_name); + dm_item->setData(kRoleMemberCount, item.member_count); + dm_item->setSizeHint(QSize(200, 64)); + auto* dm_row = CreateConversationItemWidget(item, item.id == selected_id, true); + conversations_list_->setItemWidget(dm_item, dm_row); + + auto* contact_item = new QListWidgetItem(contacts_panel_list_); + contact_item->setData(kRoleConversationId, item.id); + contact_item->setData(kRoleTitle, item.title); + contact_item->setData(kRoleSubtitle, item.subtitle); + contact_item->setData(kRoleStatus, item.status); + contact_item->setData(kRoleConversationType, item.conversation_type); + contact_item->setData(kRoleCanManageMembers, item.can_manage_members); + contact_item->setData(kRolePeerAddress, item.peer_address); + contact_item->setData(kRoleGroupName, item.group_name); + contact_item->setData(kRoleMemberCount, item.member_count); + contact_item->setSizeHint(QSize(200, 54)); + ConversationListItemView contact_view = item; + contact_view.subtitle = QString("Status: %1").arg(NormalizePresenceStatus(item.status)); + auto* contact_row = CreateConversationItemWidget(contact_view, item.id == selected_id, false); + contacts_panel_list_->setItemWidget(contact_item, contact_row); if (!selected_id.isEmpty() && item.id == selected_id) { - conversations_list_->setCurrentItem(list_item); + conversations_list_->setCurrentItem(dm_item); + contacts_panel_list_->setCurrentItem(contact_item); } } @@ -532,39 +866,124 @@ void ChatWidget::SetTimelineHasMessages(bool has_messages) { } void ChatWidget::UpdateThreadHeader() { + if (contacts_mode_active_) { + thread_title_label_->setText("Contacts"); + thread_title_label_->setReadOnly(true); + thread_title_label_->setFocusPolicy(Qt::NoFocus); + thread_title_label_->setCursor(Qt::ArrowCursor); + thread_title_label_->setToolTip({}); + return; + } const auto* selected = conversations_list_->currentItem(); if (selected == nullptr) { thread_title_label_->setText("Select a conversation"); + thread_title_label_->setReadOnly(true); + thread_title_label_->setFocusPolicy(Qt::NoFocus); + thread_title_label_->setCursor(Qt::ArrowCursor); + thread_title_label_->setToolTip({}); empty_state_label_->setText("Select a conversation to start chatting."); return; } - const QString title = selected->data(Qt::UserRole + 1).toString(); + const QString title = selected->data(kRoleTitle).toString(); thread_title_label_->setText(title.isEmpty() ? "Direct message" : title); + const QString conversation_type = selected->data(kRoleConversationType).toString().trimmed().toLower(); + const bool can_manage_members = selected->data(kRoleCanManageMembers).toBool(); + const bool owner_group = conversation_type == "group" && can_manage_members; + thread_title_label_->setReadOnly(!owner_group); + thread_title_label_->setFocusPolicy(owner_group ? Qt::ClickFocus : Qt::NoFocus); + thread_title_label_->setCursor(owner_group ? Qt::IBeamCursor : Qt::ArrowCursor); + thread_title_label_->setToolTip( + owner_group ? "Press Enter to rename this group." : QString()); if (messages_list_->count() == 0) { empty_state_label_->setText("No messages yet. Send the first encrypted message."); } } void ChatWidget::RefreshConversationSelectionStyles() { - for (int index = 0; index < conversations_list_->count(); ++index) { - auto* item = conversations_list_->item(index); - auto* row = conversations_list_->itemWidget(item); + RefreshListSelectionStyles(conversations_list_); + RefreshListSelectionStyles(contacts_panel_list_); +} + +void ChatWidget::RefreshListSelectionStyles(QListWidget* list) { + if (list == nullptr) { + return; + } + for (int index = 0; index < list->count(); ++index) { + auto* item = list->item(index); + auto* row = list->itemWidget(item); if (row == nullptr) { continue; } - const bool selected = conversations_list_->currentRow() == index; + const bool selected = list->currentRow() == index; row->setProperty("selected", selected); row->style()->unpolish(row); row->style()->polish(row); } } +void ChatWidget::SyncContactSelection(const QString& conversation_id) { + QSignalBlocker signal_blocker(contacts_panel_list_); + if (conversation_id.isEmpty()) { + contacts_panel_list_->setCurrentItem(nullptr); + return; + } + + for (int index = 0; index < contacts_panel_list_->count(); ++index) { + auto* item = contacts_panel_list_->item(index); + if (item->data(kRoleConversationId).toString() != conversation_id) { + continue; + } + contacts_panel_list_->setCurrentItem(item); + return; + } + contacts_panel_list_->setCurrentItem(nullptr); +} + +void ChatWidget::SetContactsMode(bool enabled) { + contacts_mode_active_ = enabled; + if (content_stack_ != nullptr) { + content_stack_->setCurrentWidget(contacts_mode_active_ ? contacts_panel_ : chat_panel_); + } + UpdateContactsButtonState(); + UpdateThreadHeader(); + UpdateCallControls(); +} + +void ChatWidget::UpdateContactsButtonState() { + if (contacts_button_ != nullptr) { + contacts_button_->setChecked(contacts_mode_active_); + } +} + +QString ChatWidget::SelectedConversationType() const { + const auto* item = conversations_list_->currentItem(); + if (item == nullptr) { + return "direct"; + } + const QString value = item->data(kRoleConversationType).toString().trimmed().toLower(); + return value.isEmpty() ? "direct" : value; +} + +bool ChatWidget::SelectedConversationCanManageMembers() const { + const auto* item = conversations_list_->currentItem(); + if (item == nullptr) { + return false; + } + return item->data(kRoleCanManageMembers).toBool(); +} + void ChatWidget::UpdateCallControls() { const QString state = call_state_.state.trimmed().toLower(); const bool has_selection = !SelectedConversation().isEmpty(); + const bool in_contacts_mode = contacts_mode_active_; + const QString conversation_type = SelectedConversationType(); + const bool selected_direct = conversation_type == "direct"; + const bool selected_group = conversation_type == "group"; + const bool selected_owner_managed_group = selected_group && SelectedConversationCanManageMembers(); const bool panel_visible = state == "incoming_ringing" || state == "outgoing_ringing" || state == "active" || state == "ending"; + const bool idle_and_chat_ready = !in_contacts_mode && has_selection && state == "idle" && !panel_visible; call_status_label_->setText(FormatCallStatusText(call_state_)); call_status_label_->setProperty("state", CallStatusStyleState(call_state_)); @@ -572,8 +991,12 @@ void ChatWidget::UpdateCallControls() { call_status_label_->style()->polish(call_status_label_); call_panel_->setVisible(panel_visible); - call_button_->setVisible(state == "idle" && !panel_visible); - call_button_->setEnabled(has_selection); + call_button_->setVisible(idle_and_chat_ready); + call_button_->setEnabled(has_selection && !in_contacts_mode); + group_button_->setVisible(idle_and_chat_ready && selected_direct); + group_button_->setEnabled(idle_and_chat_ready && selected_direct); + invite_button_->setVisible(idle_and_chat_ready && selected_owner_managed_group); + invite_button_->setEnabled(idle_and_chat_ready && selected_owner_managed_group); accept_call_button_->setVisible(state == "incoming_ringing"); decline_call_button_->setVisible(state == "incoming_ringing"); mute_call_button_->setVisible(state == "active"); @@ -588,6 +1011,35 @@ void ChatWidget::UpdateCallControls() { call_panel_title_->setText(QString("Voice with %1").arg(peer_title)); call_panel_subtitle_->setText(FormatCallPanelSubtitle(call_state_)); + while (call_participants_layout_ != nullptr && call_participants_layout_->count() > 0) { + QLayoutItem* item = call_participants_layout_->takeAt(0); + if (item == nullptr) { + continue; + } + if (item->widget() != nullptr) { + item->widget()->deleteLater(); + } + delete item; + } + const bool show_participants = state == "active" && !call_state_.participants.empty(); + if (show_participants && call_participants_layout_ != nullptr) { + for (const auto& participant : call_state_.participants) { + QString label = participant.label.trimmed(); + if (label.isEmpty()) { + label = participant.user_address.trimmed(); + } + if (label.isEmpty()) { + label = participant.self ? "You" : "Member"; + } + auto* chip = new QLabel(label, call_participants_panel_); + chip->setObjectName("callParticipantChip"); + chip->setProperty("self", participant.self); + call_participants_layout_->addWidget(chip); + } + call_participants_layout_->addStretch(1); + } + call_participants_panel_->setVisible(show_participants); + QChar initial = peer_title.isEmpty() ? QChar('#') : peer_title.at(0).toUpper(); call_panel_avatar_->setText(QString(initial)); @@ -596,6 +1048,128 @@ void ChatWidget::UpdateCallControls() { call_panel_->style()->polish(call_panel_); } +void ChatWidget::ShowGroupInviteDialog(const GroupInvitePickerView& picker) { + QDialog dialog(this); + dialog.setWindowTitle("Invite Contacts"); + dialog.setModal(true); + dialog.resize(500, 600); + + auto* layout = new QVBoxLayout(&dialog); + layout->setContentsMargins(14, 14, 14, 14); + layout->setSpacing(10); + + const int remaining_slots = std::max(0, picker.remaining_slots); + auto* header_label = new QLabel(QString("You can add %1 more members.").arg(remaining_slots), &dialog); + header_label->setWordWrap(true); + layout->addWidget(header_label); + + auto* search_input = new QLineEdit(&dialog); + search_input->setPlaceholderText("Type username or address"); + layout->addWidget(search_input); + + auto* list = new QListWidget(&dialog); + list->setSelectionMode(QAbstractItemView::NoSelection); + list->setVerticalScrollMode(QAbstractItemView::ScrollPerPixel); + layout->addWidget(list, 1); + + auto* hint_label = new QLabel(&dialog); + hint_label->setWordWrap(true); + layout->addWidget(hint_label); + + auto* actions = new QHBoxLayout(); + actions->addStretch(1); + auto* cancel_button = new QPushButton("Cancel", &dialog); + cancel_button->setObjectName("secondaryButton"); + auto* invite_action_button = new QPushButton("Invite", &dialog); + invite_action_button->setObjectName("primaryButton"); + actions->addWidget(cancel_button); + actions->addWidget(invite_action_button); + layout->addLayout(actions); + + std::vector selected_addresses; + selected_addresses.reserve(static_cast(remaining_slots)); + const auto is_selected = [&selected_addresses](const QString& value) { + return std::find(selected_addresses.begin(), selected_addresses.end(), value) != selected_addresses.end(); + }; + const auto update_action_state = [&]() { + invite_action_button->setEnabled(remaining_slots > 0 && !selected_addresses.empty()); + }; + + const auto repopulate = [&]() { + const QString query = search_input->text().trimmed().toLower(); + list->clear(); + int shown = 0; + for (const auto& candidate : picker.candidates) { + const QString haystack = QString("%1 %2 %3") + .arg(candidate.title, candidate.subtitle, candidate.peer_address) + .trimmed() + .toLower(); + if (!query.isEmpty() && !haystack.contains(query)) { + continue; + } + auto* item = new QListWidgetItem( + QString("%1\n%2").arg(candidate.title, candidate.subtitle.isEmpty() ? candidate.peer_address : candidate.subtitle), + list); + item->setData(Qt::UserRole, candidate.peer_address); + item->setFlags(item->flags() | Qt::ItemIsUserCheckable | Qt::ItemIsEnabled); + item->setCheckState(is_selected(candidate.peer_address) ? Qt::Checked : Qt::Unchecked); + shown += 1; + } + + if (remaining_slots <= 0) { + hint_label->setText("Member limit reached. No more invites can be sent."); + } else if (picker.candidates.empty()) { + hint_label->setText("No direct-message contacts are available to invite."); + } else if (shown == 0) { + hint_label->setText("No contacts match your search."); + } else { + hint_label->setText(QString("Select up to %1 contact(s).").arg(remaining_slots)); + } + update_action_state(); + }; + + connect(search_input, &QLineEdit::textChanged, &dialog, [&]() { + repopulate(); + }); + connect(list, &QListWidget::itemChanged, &dialog, [&](QListWidgetItem* item) { + if (item == nullptr) { + return; + } + const QString address = item->data(Qt::UserRole).toString().trimmed().toLower(); + if (address.isEmpty()) { + return; + } + if (item->checkState() == Qt::Checked) { + if (is_selected(address)) { + return; + } + if (static_cast(selected_addresses.size()) >= remaining_slots) { + item->setCheckState(Qt::Unchecked); + hint_label->setText(QString("You can only invite up to %1 contact(s).").arg(remaining_slots)); + return; + } + selected_addresses.push_back(address); + } else { + selected_addresses.erase( + std::remove(selected_addresses.begin(), selected_addresses.end(), address), + selected_addresses.end()); + } + update_action_state(); + }); + connect(cancel_button, &QPushButton::clicked, &dialog, &QDialog::reject); + connect(invite_action_button, &QPushButton::clicked, &dialog, [&]() { + if (selected_addresses.empty()) { + return; + } + emit InviteGroupMembersRequested(selected_addresses); + dialog.accept(); + }); + + repopulate(); + update_action_state(); + dialog.exec(); +} + void ChatWidget::SetConnectionStatus(const QString& status) { status_label_->setText(status); @@ -617,9 +1191,46 @@ void ChatWidget::SetConnectionStatus(const QString& status) { void ChatWidget::SetCallState(const CallStateView& state) { call_state_ = state; + if (contacts_mode_active_ && call_state_.state.trimmed().toLower() != "idle") { + SetContactsMode(false); + } UpdateCallControls(); } +void ChatWidget::SetUserStatus(const QString& status) { + const QString normalized = NormalizePresenceStatus(status); + QSignalBlocker blocker(presence_combo_); + int target_index = -1; + for (int index = 0; index < presence_combo_->count(); ++index) { + if (NormalizePresenceStatus(presence_combo_->itemData(index).toString()) == normalized) { + target_index = index; + break; + } + } + if (target_index < 0) { + target_index = 0; + } + presence_combo_->setCurrentIndex(target_index); + UpdatePresenceIndicator(normalized); +} + +void ChatWidget::UpdatePresenceIndicator(const QString& status) { + if (presence_indicator_ == nullptr) { + return; + } + const QString normalized = NormalizePresenceStatus(status); + presence_indicator_->setProperty("state", normalized); + presence_indicator_->setStyleSheet( + QString( + "background:%1;" + "border-radius:5px;" + "min-width:10px;max-width:10px;" + "min-height:10px;max-height:10px;") + .arg(PresenceColorForStatus(normalized))); + presence_indicator_->style()->unpolish(presence_indicator_); + presence_indicator_->style()->polish(presence_indicator_); +} + void ChatWidget::SetIdentity(const QString& user_address) { if (user_address.trimmed().isEmpty()) { identity_value_.clear(); diff --git a/client-cpp-gui/src/ui/main_window.cpp b/client-cpp-gui/src/ui/main_window.cpp index f9e92b8..33abc1b 100644 --- a/client-cpp-gui/src/ui/main_window.cpp +++ b/client-cpp-gui/src/ui/main_window.cpp @@ -1,5 +1,7 @@ #include "blackwire/ui/main_window.hpp" +#include + #include #include #include @@ -56,8 +58,82 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) controller_.SelectConversation(id); }); + connect(chat_widget_, &ChatWidget::CreateGroupFromDmRequested, this, [this]() { + if (controller_.CreateGroupFromCurrentDm()) { + chat_widget_->ShowBanner("Group created.", "info"); + } + }); + + connect(chat_widget_, &ChatWidget::GroupInviteDialogRequested, this, [this]() { + if (!controller_.IsSelectedConversationOwnerManagedGroup()) { + chat_widget_->ShowBanner("Only the group owner can invite members.", "warning"); + return; + } + const auto picker = controller_.LoadInvitableContactsForCurrentGroup(QString()); + chat_widget_->ShowGroupInviteDialog(picker); + }); + + connect(chat_widget_, &ChatWidget::InviteGroupMembersRequested, this, [this](const std::vector& addresses) { + if (controller_.InviteContactsToCurrentGroup(addresses)) { + chat_widget_->ShowBanner("Invites sent.", "info"); + } + }); + + connect(chat_widget_, &ChatWidget::GroupRenameRequested, this, [this](const QString& name) { + if (controller_.RenameSelectedGroup(name)) { + chat_widget_->ShowBanner("Group name updated.", "info"); + } + }); + + connect( + chat_widget_, + &ChatWidget::ConversationRemoveRequested, + this, + [this](const QString& conversation_id, const QString& conversation_type, bool can_manage_members, const QString& title) { + const QString id = conversation_id.trimmed(); + if (id.isEmpty()) { + return; + } + const QString normalized_type = conversation_type.trimmed().toLower(); + if (normalized_type == "group") { + const QString question = can_manage_members + ? "Leave this group? Ownership will be transferred automatically if needed." + : "Leave this group?"; + const auto answer = QMessageBox::question( + this, + "Leave Group", + question); + if (answer != QMessageBox::Yes) { + return; + } + if (controller_.LeaveGroupConversation(id)) { + chat_widget_->ShowBanner("Left group DM.", "info"); + } + return; + } + + const auto answer = QMessageBox::question( + this, + "Remove DM", + QString("Remove %1 from Messages? This only clears local cache.").arg(title.trimmed().isEmpty() ? "this DM" : title)); + if (answer != QMessageBox::Yes) { + return; + } + if (controller_.DismissDirectConversation(id)) { + chat_widget_->ShowBanner("DM removed from Messages.", "info"); + } + }); + connect(chat_widget_, &ChatWidget::SendMessageRequested, this, [this]() { - controller_.SendMessageToPeer(chat_widget_->PeerUsername(), chat_widget_->ComposeText()); + controller_.SendMessageToPeer(QString(), chat_widget_->ComposeText()); + }); + + connect(chat_widget_, &ChatWidget::SendFileRequested, this, [this](const QString& file_path) { + controller_.SendFileToPeer(QString(), file_path); + }); + + connect(chat_widget_, &ChatWidget::UserStatusChanged, this, [this](const QString& status) { + controller_.SetPresenceStatus(status); }); connect(chat_widget_, &ChatWidget::StartVoiceCallRequested, this, [this]() { @@ -82,11 +158,13 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) connect(chat_widget_, &ChatWidget::SettingsRequested, this, [this]() { controller_.LoadAudioDevices(); + controller_.LoadAccountDevices(); settings_dialog_->SetIdentity(controller_.UserDisplayId()); settings_dialog_->SetServerUrl(controller_.BaseUrl()); settings_dialog_->SetDeviceInfo(controller_.DeviceLabel(), controller_.DeviceId()); settings_dialog_->SetConnectionStatus(controller_.ConnectionStatus()); settings_dialog_->SetDiagnostics(controller_.DiagnosticsReport()); + settings_dialog_->SetIntegrityWarning(last_integrity_warning_); settings_dialog_->SetAcceptMessagesFromStrangers(controller_.AcceptMessagesFromStrangers()); settings_dialog_->exec(); }); @@ -103,6 +181,17 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) controller_.SetAcceptMessagesFromStrangers(enabled); }); + connect(settings_dialog_, &SettingsDialog::RevokeDeviceRequested, this, [this](const QString& device_uid) { + const auto answer = QMessageBox::question( + this, + "Kick Device", + QString("Revoke device %1? It will be disconnected immediately.").arg(device_uid)); + if (answer != QMessageBox::Yes) { + return; + } + controller_.RevokeDevice(device_uid); + }); + connect(settings_dialog_, &SettingsDialog::LogoutRequested, this, [this]() { controller_.Logout(); }); @@ -120,6 +209,9 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) chat_widget_->SetIdentity(authenticated ? controller_.UserDisplayId() : QString()); chat_widget_->SetSettingsVisible(authenticated); if (!authenticated) { + last_integrity_warning_.clear(); + settings_dialog_->SetIntegrityWarning(last_integrity_warning_); + settings_dialog_->SetAccountDevices(std::vector{}, QString()); login_widget_->SetBaseUrl(controller_.BaseUrl()); stack->setCurrentWidget(login_widget_); return; @@ -201,6 +293,10 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) statusBar()->showMessage(status, 3000); }); + connect(&controller_, &ApplicationController::UserPresenceChanged, this, [this](const QString& status) { + chat_widget_->SetUserStatus(status); + }); + connect(&controller_, &ApplicationController::CallStateChanged, this, [this](const CallStateView& state) { chat_widget_->SetCallState(state); }); @@ -222,6 +318,14 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) settings_dialog_->SetAudioDevices(inputs, outputs); }); + connect( + &controller_, + &ApplicationController::AccountDevicesChanged, + this, + [this](const std::vector& devices) { + settings_dialog_->SetAccountDevices(devices, controller_.DeviceId()); + }); + connect( &controller_, &ApplicationController::AudioDevicePreferenceChanged, @@ -235,6 +339,13 @@ MainWindow::MainWindow(ApplicationController& controller, QWidget* parent) statusBar()->showMessage(error, 5000); }); + connect(&controller_, &ApplicationController::IntegrityWarningOccurred, this, [this](const QString& warning) { + last_integrity_warning_ = warning.trimmed(); + settings_dialog_->SetIntegrityWarning(last_integrity_warning_); + chat_widget_->ShowBanner(last_integrity_warning_, "warning"); + statusBar()->showMessage(last_integrity_warning_, 7000); + }); + connect(&controller_, &ApplicationController::ErrorOccurred, this, [this, stack](const QString& error) { if (stack->currentWidget() == login_widget_ || stack->currentWidget() == device_widget_) { QMessageBox::warning(this, "Blackwire", error); diff --git a/client-cpp-gui/src/ui/settings_dialog.cpp b/client-cpp-gui/src/ui/settings_dialog.cpp index 23a43ce..7b9dd77 100644 --- a/client-cpp-gui/src/ui/settings_dialog.cpp +++ b/client-cpp-gui/src/ui/settings_dialog.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -16,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +26,7 @@ #include #include +#include "blackwire/models/dto.hpp" #include "blackwire/models/view_models.hpp" namespace blackwire { @@ -37,6 +40,10 @@ constexpr int kAudioMeterDecayMs = 45; constexpr int kAudioMeterDecayStep = 5; constexpr int kMonitorOutputQueueLimit = 96000; +std::string DeviceUid(const DeviceOut& device) { + return device.device_uid.empty() ? device.id : device.device_uid; +} + } // namespace SettingsDialog::SettingsDialog(QWidget* parent) : QDialog(parent) { @@ -71,6 +78,25 @@ SettingsDialog::SettingsDialog(QWidget* parent) : QDialog(parent) { QGuiApplication::clipboard()->setText(diagnostics_); } }); + QObject::connect(account_devices_list_, &QListWidget::itemSelectionChanged, this, [this]() { + const auto* item = account_devices_list_->currentItem(); + if (item == nullptr) { + selected_account_device_uid_.clear(); + revoke_device_button_->setEnabled(false); + return; + } + selected_account_device_uid_ = item->data(Qt::UserRole).toString(); + const QString status = item->data(Qt::UserRole + 1).toString().trimmed().toLower(); + const bool is_self = selected_account_device_uid_ == current_device_uid_; + revoke_device_button_->setEnabled( + !selected_account_device_uid_.isEmpty() && status == "active" && !is_self); + }); + QObject::connect(revoke_device_button_, &QPushButton::clicked, this, [this]() { + if (selected_account_device_uid_.isEmpty()) { + return; + } + emit RevokeDeviceRequested(selected_account_device_uid_); + }); QObject::connect(apply_audio_button_, &QPushButton::clicked, this, [this]() { emit ApplyAudioDevicesRequested( input_device_combo_->currentData().toString(), @@ -211,6 +237,22 @@ void SettingsDialog::BuildMyAccountPage() { form->addRow("Connection", connection_status_value_); layout->addLayout(form); + auto* devices_title = new QLabel("Authorized Devices", page); + devices_title->setObjectName("conversationSubtitle"); + layout->addWidget(devices_title); + + account_devices_list_ = new QListWidget(page); + account_devices_list_->setObjectName("conversationList"); + account_devices_list_->setSelectionMode(QAbstractItemView::SingleSelection); + account_devices_list_->setVerticalScrollMode(QAbstractItemView::ScrollPerPixel); + account_devices_list_->setMinimumHeight(170); + layout->addWidget(account_devices_list_); + + revoke_device_button_ = new QPushButton("Kick Selected Device", page); + revoke_device_button_->setObjectName("dangerButton"); + revoke_device_button_->setEnabled(false); + layout->addWidget(revoke_device_button_, 0, Qt::AlignLeft); + auto* actions = new QHBoxLayout(); actions->setSpacing(8); copy_id_button_ = new QPushButton("Copy ID", page); @@ -317,6 +359,17 @@ void SettingsDialog::BuildPrivacyPage() { copy_diagnostics_button_ = new QPushButton("Copy Diagnostics", page); copy_diagnostics_button_->setObjectName("secondaryButton"); layout->addWidget(copy_diagnostics_button_, 0, Qt::AlignLeft); + + auto* integrity_title = new QLabel("Integrity Status", page); + integrity_title->setObjectName("conversationSubtitle"); + layout->addWidget(integrity_title); + + integrity_warning_value_ = new QLabel("No integrity warnings observed in this session.", page); + integrity_warning_value_->setObjectName("connectionPill"); + integrity_warning_value_->setProperty("state", "connected"); + integrity_warning_value_->setWordWrap(true); + layout->addWidget(integrity_warning_value_); + layout->addStretch(1); tabs_stack_->addWidget(page); @@ -396,6 +449,36 @@ void SettingsDialog::SetDiagnostics(const QString& diagnostics) { diagnostics_ = diagnostics; } +void SettingsDialog::SetAccountDevices(const std::vector& devices, const QString& current_device_uid) { + if (account_devices_list_ == nullptr || revoke_device_button_ == nullptr) { + return; + } + current_device_uid_ = current_device_uid.trimmed(); + selected_account_device_uid_.clear(); + revoke_device_button_->setEnabled(false); + + const QSignalBlocker blocker(account_devices_list_); + account_devices_list_->clear(); + + for (const auto& device : devices) { + const QString uid = QString::fromStdString(DeviceUid(device)); + const QString label = + QString::fromStdString(device.label).trimmed().isEmpty() ? "(unlabeled)" : QString::fromStdString(device.label).trimmed(); + const QString status = QString::fromStdString(device.status).trimmed().toLower().isEmpty() + ? "active" + : QString::fromStdString(device.status).trimmed().toLower(); + const QString suffix = uid == current_device_uid_ ? " (this device)" : ""; + + auto* row = new QListWidgetItem(account_devices_list_); + row->setData(Qt::UserRole, uid); + row->setData(Qt::UserRole + 1, status); + row->setText(QString("%1 [%2]%3\n%4").arg(label, status, suffix, uid)); + if (uid == current_device_uid_) { + account_devices_list_->setCurrentItem(row); + } + } +} + void SettingsDialog::SetAudioDevices( const std::vector& input_devices, const std::vector& output_devices) { @@ -448,6 +531,21 @@ void SettingsDialog::SetAcceptMessagesFromStrangers(bool enabled) { accept_messages_checkbox_->setChecked(enabled); } +void SettingsDialog::SetIntegrityWarning(const QString& warning) { + if (integrity_warning_value_ == nullptr) { + return; + } + if (warning.trimmed().isEmpty()) { + integrity_warning_value_->setText("No integrity warnings observed in this session."); + integrity_warning_value_->setProperty("state", "connected"); + } else { + integrity_warning_value_->setText(warning.trimmed()); + integrity_warning_value_->setProperty("state", "warning"); + } + integrity_warning_value_->style()->unpolish(integrity_warning_value_); + integrity_warning_value_->style()->polish(integrity_warning_value_); +} + bool SettingsDialog::AcceptMessagesFromStrangers() const { if (accept_messages_checkbox_ == nullptr) { return true; diff --git a/client-cpp-gui/src/ui/theme.cpp b/client-cpp-gui/src/ui/theme.cpp index 0ee8007..4cad099 100644 --- a/client-cpp-gui/src/ui/theme.cpp +++ b/client-cpp-gui/src/ui/theme.cpp @@ -49,6 +49,10 @@ QLineEdit, QPlainTextEdit, QListWidget, QComboBox { padding: 6px 8px; } +QComboBox#statusCombo { + min-width: 96px; +} + QLineEdit:focus, QPlainTextEdit:focus, QComboBox:focus { border-color: #5865f2; } @@ -97,6 +101,12 @@ QPushButton#secondaryButton { background: #3a3d44; } +QPushButton#secondaryButton:checked { + background: #4e5ada; + border: 1px solid #4a57d8; + color: #ffffff; +} + QPushButton#dangerButton { background: #da373c; border: 1px solid #b32f33; @@ -148,12 +158,24 @@ QLabel#dmSidebarTitle { color: #ffffff; } +QLabel#contactsSidebarTitle { + font-size: 13px; + font-weight: 700; + color: #d6d9de; +} + QWidget#identityCard { background: #232428; border: 1px solid #1b1c20; border-radius: 10px; } +QWidget#contactsToggleCard { + background: #232428; + border: 1px solid #1b1c20; + border-radius: 10px; +} + QLabel#identityLabel { color: #d6d9de; } @@ -162,10 +184,19 @@ QWidget#chatPane { background: #313338; } -QLabel#threadTitle { +QLabel#threadTitle, QLineEdit#threadTitle { font-size: 15px; font-weight: 700; color: #ffffff; + background: transparent; + border: 1px solid transparent; + border-radius: 6px; + padding: 2px 6px; +} + +QLineEdit#threadTitle:focus { + border-color: #5865f2; + background: #232428; } QLabel#connectionPill { @@ -261,6 +292,22 @@ QLabel#callPanelSubtitle { color: #c6cbd2; } +QLabel#callParticipantChip { + background: #31353d; + border: 1px solid #1b1c20; + border-radius: 10px; + padding: 3px 8px; + color: #d6d9de; + font-size: 12px; +} + +QLabel#callParticipantChip[self=\"true\"] { + background: #4250c7; + border-color: #4a57d8; + color: #ffffff; + font-weight: 600; +} + QLabel#chatBanner { border-radius: 8px; padding: 8px 10px; @@ -285,6 +332,11 @@ QListWidget#conversationList { border: 1px solid #1b1c20; } +QListWidget#contactsList { + background: #2b2d31; + border: 1px solid #1b1c20; +} + QListWidget#conversationList::item { border-radius: 8px; } @@ -293,6 +345,14 @@ QListWidget#conversationList::item:selected { background: #404249; } +QListWidget#contactsList::item { + border-radius: 8px; +} + +QListWidget#contactsList::item:selected { + background: #404249; +} + QWidget#conversationRow { background: transparent; border-radius: 8px; @@ -316,6 +376,50 @@ QLabel#conversationTime { font-size: 11px; } +QPushButton#conversationRemoveButton { + min-width: 18px; + max-width: 18px; + min-height: 18px; + max-height: 18px; + border-radius: 9px; + border: 1px solid #2e3138; + background: #2c2f35; + color: #cfd3da; + padding: 0px; + font-weight: 600; +} + +QPushButton#conversationRemoveButton:hover { + background: #c63237; + border-color: #b32f33; + color: #ffffff; +} + +QLabel#presenceDot { + border-radius: 5px; + min-width: 10px; + max-width: 10px; + min-height: 10px; + max-height: 10px; + background: #6a6f78; +} + +QLabel#presenceDot[state=\"active\"] { + background: #2d7d46; +} + +QLabel#presenceDot[state=\"inactive\"] { + background: #d4a63c; +} + +QLabel#presenceDot[state=\"offline\"] { + background: #6a6f78; +} + +QLabel#presenceDot[state=\"dnd\"] { + background: #9e2c31; +} + QListWidget#messageList { background: #313338; border: none; diff --git a/client-cpp-gui/src/util/message_view.cpp b/client-cpp-gui/src/util/message_view.cpp index 64597bd..a8e2a5e 100644 --- a/client-cpp-gui/src/util/message_view.cpp +++ b/client-cpp-gui/src/util/message_view.cpp @@ -29,6 +29,18 @@ bool IsBefore(const LocalMessage& lhs, const LocalMessage& rhs) { return lhs.id < rhs.id; } +QString LabelFromSenderAddress(const std::string& sender_address) { + const QString address = QString::fromStdString(sender_address).trimmed(); + if (address.isEmpty()) { + return {}; + } + const int at = address.indexOf('@'); + if (at > 0) { + return address.left(at).trimmed(); + } + return address; +} + } // namespace QString ExtractLegacyPlaintext(const QString& rendered_text) { @@ -84,11 +96,19 @@ std::vector BuildThreadMessageViews( view.created_at_iso = QString::fromStdString(item->created_at); view.created_at_display = FormatThreadTimestamp(view.created_at_iso); view.outgoing = item->sender_user_id == self_user_id; - view.sender_label = view.outgoing ? "You" : (peer_sender_label.trimmed().isEmpty() ? "Peer" : peer_sender_label.trimmed()); + if (view.outgoing) { + view.sender_label = "You"; + } else { + const QString from_address = LabelFromSenderAddress(item->sender_address); + view.sender_label = from_address.trimmed().isEmpty() + ? (peer_sender_label.trimmed().isEmpty() ? "Peer" : peer_sender_label.trimmed()) + : from_address.trimmed(); + } const QString plaintext = QString::fromStdString(item->plaintext); view.body = plaintext.isEmpty() ? ExtractLegacyPlaintext(QString::fromStdString(item->rendered_text)) : plaintext; - view.grouped_with_previous = !views.empty() && views.back().outgoing == view.outgoing; + view.grouped_with_previous = + !views.empty() && views.back().outgoing == view.outgoing && views.back().sender_label == view.sender_label; views.push_back(view); } diff --git a/client-cpp-gui/src/ws/qt_ws_client.cpp b/client-cpp-gui/src/ws/qt_ws_client.cpp index 99536c8..ae2fe91 100644 --- a/client-cpp-gui/src/ws/qt_ws_client.cpp +++ b/client-cpp-gui/src/ws/qt_ws_client.cpp @@ -128,6 +128,7 @@ QtWsClient::QtWsClient(QObject* parent) : QObject(parent) { void QtWsClient::SetHandlers( MessageHandler on_message, CallIncomingHandler on_call_incoming, + CallGroupStateHandler on_call_group_state, CallRingingHandler on_call_ringing, CallAcceptedHandler on_call_accepted, CallRejectedHandler on_call_rejected, @@ -135,10 +136,15 @@ void QtWsClient::SetHandlers( CallEndedHandler on_call_ended, CallAudioHandler on_call_audio, CallErrorHandler on_call_error, + CallWebRtcOfferHandler on_call_webrtc_offer, + CallWebRtcAnswerHandler on_call_webrtc_answer, + CallWebRtcIceHandler on_call_webrtc_ice, + GroupRenamedHandler on_group_renamed, ErrorHandler on_error, StatusHandler on_status) { on_message_ = std::move(on_message); on_call_incoming_ = std::move(on_call_incoming); + on_call_group_state_ = std::move(on_call_group_state); on_call_ringing_ = std::move(on_call_ringing); on_call_accepted_ = std::move(on_call_accepted); on_call_rejected_ = std::move(on_call_rejected); @@ -146,6 +152,10 @@ void QtWsClient::SetHandlers( on_call_ended_ = std::move(on_call_ended); on_call_audio_ = std::move(on_call_audio); on_call_error_ = std::move(on_call_error); + on_call_webrtc_offer_ = std::move(on_call_webrtc_offer); + on_call_webrtc_answer_ = std::move(on_call_webrtc_answer); + on_call_webrtc_ice_ = std::move(on_call_webrtc_ice); + on_group_renamed_ = std::move(on_group_renamed); on_error_ = std::move(on_error); on_status_ = std::move(on_status); } @@ -199,6 +209,24 @@ void QtWsClient::SendCallAudioChunk(const VoiceAudioChunk& chunk) { socket_.sendTextMessage(QString::fromStdString(payload.dump())); } +void QtWsClient::SendCallWebRtcOffer(const VoiceCallWebRtcOffer& offer) { + nlohmann::json payload = offer; + payload["type"] = "call.webrtc.offer"; + socket_.sendTextMessage(QString::fromStdString(payload.dump())); +} + +void QtWsClient::SendCallWebRtcAnswer(const VoiceCallWebRtcAnswer& answer) { + nlohmann::json payload = answer; + payload["type"] = "call.webrtc.answer"; + socket_.sendTextMessage(QString::fromStdString(payload.dump())); +} + +void QtWsClient::SendCallWebRtcIce(const VoiceCallWebRtcIce& ice) { + nlohmann::json payload = ice; + payload["type"] = "call.webrtc.ice"; + socket_.sendTextMessage(QString::fromStdString(payload.dump())); +} + void QtWsClient::ScheduleReconnect() { reconnect_attempt_ += 1; const int max_ms = 30000; @@ -214,7 +242,7 @@ QUrl QtWsClient::BuildWsUrl() const { url.setScheme(scheme); url.setHost(base.host()); url.setPort(base.port(base.scheme() == "https" ? 443 : 80)); - url.setPath("/api/v1/ws"); + url.setPath("/api/v2/ws"); return url; } @@ -238,6 +266,14 @@ void QtWsClient::HandleTextMessage(const QString& message_text) { return; } + if (type == "call.group.incoming") { + WsEventCallIncoming event = payload.get(); + if (on_call_incoming_) { + on_call_incoming_(event); + } + return; + } + if (type == "call.ringing") { WsEventCallRinging event = payload.get(); if (on_call_ringing_) { @@ -278,6 +314,34 @@ void QtWsClient::HandleTextMessage(const QString& message_text) { return; } + if (type == "call.group.ended") { + WsEventCallEnded event = payload.get(); + if (on_call_ended_) { + on_call_ended_(event); + } + return; + } + + if (type == "call.group.state") { + WsEventCallGroupState event = payload.get(); + if (on_call_group_state_) { + on_call_group_state_(event); + } + return; + } + + if (type == "call.group.participant") { + return; + } + + if (type == "conversation.group.renamed") { + WsEventGroupRenamed event = payload.get(); + if (on_group_renamed_) { + on_group_renamed_(event); + } + return; + } + if (type == "call.audio") { WsEventCallAudio event = payload.get(); if (on_call_audio_) { @@ -297,6 +361,30 @@ void QtWsClient::HandleTextMessage(const QString& message_text) { return; } + if (type == "call.webrtc.offer") { + WsEventCallWebRtcOffer event = payload.get(); + if (on_call_webrtc_offer_) { + on_call_webrtc_offer_(event); + } + return; + } + + if (type == "call.webrtc.answer") { + WsEventCallWebRtcAnswer event = payload.get(); + if (on_call_webrtc_answer_) { + on_call_webrtc_answer_(event); + } + return; + } + + if (type == "call.webrtc.ice") { + WsEventCallWebRtcIce event = payload.get(); + if (on_call_webrtc_ice_) { + on_call_webrtc_ice_(event); + } + return; + } + if (type == "error") { if (on_error_) { on_error_("WebSocket server error: " + message_text.toStdString()); diff --git a/client-cpp-gui/tests/test_crypto.cpp b/client-cpp-gui/tests/test_crypto.cpp index 3eddc62..f651e4e 100644 --- a/client-cpp-gui/tests/test_crypto.cpp +++ b/client-cpp-gui/tests/test_crypto.cpp @@ -12,3 +12,14 @@ TEST(CryptoServiceTest, EncryptDecryptRoundTrip) { EXPECT_EQ(decrypted, plaintext); } + +TEST(CryptoServiceTest, SignVerifyRoundTrip) { + blackwire::SodiumCryptoService crypto; + const auto keys = crypto.GenerateDeviceKeys(); + + const std::string payload = "signed-payload"; + const auto signature = crypto.SignDetached(keys.ik_ed25519_private_b64, payload); + + EXPECT_TRUE(crypto.VerifyDetached(keys.ik_ed25519_public_b64, payload, signature)); + EXPECT_FALSE(crypto.VerifyDetached(keys.ik_ed25519_public_b64, payload + "x", signature)); +} diff --git a/client-cpp-gui/tests/test_envelope_serialization.cpp b/client-cpp-gui/tests/test_envelope_serialization.cpp index cf743a2..539bd42 100644 --- a/client-cpp-gui/tests/test_envelope_serialization.cpp +++ b/client-cpp-gui/tests/test_envelope_serialization.cpp @@ -23,6 +23,31 @@ TEST(EnvelopeSerializationTest, RoundTripPreservesFields) { EXPECT_EQ(restored.envelope.client_message_id, request.envelope.client_message_id); } +TEST(EnvelopeSerializationTest, V2MessageSendSerializationRoundTrip) { + blackwire::MessageSendRequest request; + request.conversation_id = "conv-v2"; + request.client_message_id = "client-v2"; + request.sent_at_ms = 12345; + request.sender_prev_hash = ""; + request.sender_chain_hash = "abc123"; + + blackwire::CipherEnvelope env; + env.recipient_user_address = "bob@peer.onion"; + env.recipient_device_uid = "device-bob"; + env.ciphertext_b64 = "Zm9v"; + env.aad_b64.clear(); + env.signature_b64 = "c2ln"; + env.sender_device_pubkey = "cHVi"; + request.envelopes = {env}; + + const nlohmann::json json = request; + const auto restored = json.get(); + + ASSERT_EQ(restored.envelopes.size(), 1); + EXPECT_EQ(restored.client_message_id, "client-v2"); + EXPECT_EQ(restored.envelopes.front().recipient_device_uid, "device-bob"); +} + TEST(EnvelopeSerializationTest, UserOutParsesCanonicalIdentityWhenPresent) { const auto json = nlohmann::json::parse( R"({"id":"u1","username":"alice","created_at":"2026-02-17T00:00:00Z","user_address":"alice@peer.onion","home_server_onion":"peer.onion"})"); @@ -71,3 +96,15 @@ TEST(EnvelopeSerializationTest, MessageOutDefaultsSenderAddressWhenMissing) { const auto message = json.get(); EXPECT_TRUE(message.sender_address.empty()); } + +TEST(EnvelopeSerializationTest, ConversationMemberOutParsesNullableFields) { + const auto json = nlohmann::json::parse( + R"({"id":"m1","member_user_id":null,"member_address":"alice@local.invalid","member_server_onion":"local.invalid","role":"member","status":"invited","invited_by_address":"bob@local.invalid","invited_at":"2026-02-24T10:00:00Z","joined_at":null,"left_at":null,"updated_at":"2026-02-24T10:00:00Z"})"); + + const auto member = json.get(); + EXPECT_EQ(member.id, "m1"); + EXPECT_TRUE(member.member_user_id.empty()); + EXPECT_EQ(member.member_address, "alice@local.invalid"); + EXPECT_TRUE(member.joined_at.empty()); + EXPECT_TRUE(member.left_at.empty()); +} diff --git a/client-cpp-gui/tests/test_message_view.cpp b/client-cpp-gui/tests/test_message_view.cpp index e0847dc..d365f3f 100644 --- a/client-cpp-gui/tests/test_message_view.cpp +++ b/client-cpp-gui/tests/test_message_view.cpp @@ -35,6 +35,30 @@ TEST(MessageViewTest, GroupsConsecutiveMessagesFromSameSender) { EXPECT_FALSE(views[2].outgoing); } +TEST(MessageViewTest, UsesSenderAddressLabelForGroupMembersAndSeparatesSpeakers) { + blackwire::LocalMessage first; + first.id = "1"; + first.sender_user_id = "peer-a-id"; + first.sender_address = "alice@server-a.onion"; + first.created_at = "2026-02-14T10:00:00Z"; + first.plaintext = "hello"; + + blackwire::LocalMessage second; + second.id = "2"; + second.sender_user_id = "peer-b-id"; + second.sender_address = "bob@server-b.onion"; + second.created_at = "2026-02-14T10:01:00Z"; + second.plaintext = "hi"; + + std::vector input = {first, second}; + const auto views = blackwire::BuildThreadMessageViews(input, "self-id", "Member"); + + ASSERT_EQ(views.size(), 2U); + EXPECT_EQ(views[0].sender_label.toStdString(), "alice"); + EXPECT_EQ(views[1].sender_label.toStdString(), "bob"); + EXPECT_FALSE(views[1].grouped_with_previous); +} + TEST(MessageViewTest, FormatsInvalidTimestampWithFallback) { const auto display = blackwire::FormatThreadTimestamp("not-a-real-time"); EXPECT_EQ(display.toStdString(), "not-a-real-time"); diff --git a/client-cpp-gui/tests/test_state_store_compat.cpp b/client-cpp-gui/tests/test_state_store_compat.cpp index f9e4183..cba723c 100644 --- a/client-cpp-gui/tests/test_state_store_compat.cpp +++ b/client-cpp-gui/tests/test_state_store_compat.cpp @@ -11,6 +11,7 @@ TEST(StateCompatTest, LoadsLegacyLocalMessageWithoutPlaintext) { const auto message = json.get(); EXPECT_EQ(message.id, "m1"); EXPECT_EQ(message.plaintext, ""); + EXPECT_EQ(message.plaintext_cache_b64, ""); } TEST(StateCompatTest, DoesNotPersistLocalMessagePlaintextToDisk) { @@ -24,11 +25,21 @@ TEST(StateCompatTest, DoesNotPersistLocalMessagePlaintextToDisk) { const nlohmann::json json = message; EXPECT_FALSE(json.contains("plaintext")); + EXPECT_TRUE(json.contains("plaintext_cache_b64")); const auto roundtrip = json.get(); EXPECT_TRUE(roundtrip.plaintext.empty()); } +TEST(StateCompatTest, LoadsEncryptedPlaintextCacheWithoutPlaintext) { + const auto json = nlohmann::json::parse( + R"({"id":"m3","conversation_id":"c3","sender_user_id":"u3","created_at":"2026-02-14T12:00:00Z","rendered_text":"[2026] Me: encrypted","plaintext_cache_b64":"Zm9v"})"); + + const auto message = json.get(); + EXPECT_EQ(message.plaintext, ""); + EXPECT_EQ(message.plaintext_cache_b64, "Zm9v"); +} + TEST(StateCompatTest, LoadsLegacyClientStateWithoutAudioPreferences) { const auto json = nlohmann::json::parse( R"({"base_url":"http://localhost:8000","has_user":false,"has_device":false,"conversations":[]})"); @@ -86,6 +97,28 @@ TEST(StateCompatTest, PersistsSocialPreferencesAndBlockedConversationsRoundTrip) EXPECT_TRUE(roundtrip.blocked_conversation_ids.contains("conv-2")); } +TEST(StateCompatTest, LoadsLegacyClientStateWithoutDismissedConversations) { + const auto json = nlohmann::json::parse( + R"({"base_url":"http://localhost:8000","has_user":false,"has_device":false,"conversations":[]})"); + + const auto state = json.get(); + EXPECT_TRUE(state.dismissed_conversation_ids.empty()); +} + +TEST(StateCompatTest, PersistsDismissedConversationsRoundTrip) { + blackwire::ClientState state; + state.base_url = "http://localhost:8000"; + state.dismissed_conversation_ids.insert("dm-1"); + state.dismissed_conversation_ids.insert("dm-2"); + + const nlohmann::json serialized = state; + ASSERT_TRUE(serialized.contains("dismissed_conversation_ids")); + + const auto roundtrip = serialized.get(); + EXPECT_TRUE(roundtrip.dismissed_conversation_ids.contains("dm-1")); + EXPECT_TRUE(roundtrip.dismissed_conversation_ids.contains("dm-2")); +} + TEST(StateCompatTest, ConversationMetaPeerAddressRoundTrip) { blackwire::ConversationMeta meta; meta.peer_username = "alice"; diff --git a/index.html b/index.html new file mode 100644 index 0000000..1ba8276 --- /dev/null +++ b/index.html @@ -0,0 +1,608 @@ + + + + + + Blackwire v0.2 | Security-First Federated Messaging + + + + + + + + + + +
+
+
+
v0.2 Release
+

Device-bound trust and ratchet-ready federated messaging.

+

+ Blackwire v0.2 ships signed per-device envelopes, true multi-device accounts with revoke, + asymmetric device-bound JWT sessions, integrity chains, and staged ratchet/WebRTC rollout. +

+ +
+
+
+ +
+
$ docker compose -f infra/docker-compose.yml up --build
+$env:BLACKWIRE_ENABLE_RATCHET_V2B1 = "true"
+$env:BLACKWIRE_ENABLE_WEBRTC_V2B2 = "true"
+
+v2 auth: /api/v2/auth/login -> /api/v2/auth/bind-device
+v2 send: /api/v2/messages/send (sealedbox_v0_2a | ratchet_v0_2b1)
+v2 federation: /api/v2/federation/well-known
+v2 websocket: /api/v2/ws (Bearer auth)
+
+
+ +
+
+

v0.2 Highlights

+

+ v0.2a security milestones are complete, with v0.2b ratchet and WebRTC paths delivered behind rollout gates. +

+
+
+
+

Client-Signed Envelopes

+

Outbound messages are signed per device, and receivers verify signatures before accept/decrypt.

+
+
+

Sender Key Pinning

+

TOFU pinning by sender_device_uid raises integrity warnings on signing-key changes.

+
+
+

Multi-Device + Revoke

+

Accounts now support multiple active devices with immediate revoke/kick enforcement.

+
+
+

Device-Bound JWT Sessions

+

Bootstrap-to-bind flow issues asymmetric tokens with device identity claims validated on every request.

+
+
+

Integrity Chain Detection

+

Signed chain fields detect gaps/reordering/tampering and surface explicit integrity warnings in UI.

+
+
+

Ratchet + WebRTC Rollout

+

Prekey/X3DH + Double Ratchet and WebRTC signaling/media are wired through v2 dual-stack contracts.

+
+
+
+ +
+
+

v0.2 Rollout Model

+

+ Migration is incremental: v2 features run in parallel with v1 compatibility while ratchet and WebRTC move to default. +

+
+
+
+

Parallel API Surface

+

/api/v2 is implemented without removing /api/v1, allowing staged client migration.

+
+
+

Mode Negotiation

+

encryption_mode supports sealed-box fallback and ratchet mode with explicit policy gates.

+
+
+

Federation Trust Binding

+

v2 federation well-known metadata binds trust to Tor onion identity while preserving signed relay controls.

+
+
+

Voice Transport Migration

+

WebRTC signaling and media are added in v2 while legacy WS PCM can be disabled after stabilization.

+
+
+
+ +
+
+

API Surface (/api/v2)

+

+ Selected v0.2 release endpoints. Legacy /api/v1 remains available for compatibility. +

+
+
    +
  • POST /api/v2/auth/register
  • +
  • POST /api/v2/auth/login
  • +
  • POST /api/v2/auth/bind-device
  • +
  • POST /api/v2/auth/refresh
  • +
  • POST /api/v2/auth/logout
  • +
  • POST /api/v2/devices/register
  • +
  • GET /api/v2/devices
  • +
  • POST /api/v2/devices/{device_uid}/revoke
  • +
  • GET /api/v2/me
  • +
  • GET /api/v2/users/resolve-devices?peer_address=...
  • +
  • GET /api/v2/users/resolve-prekeys?peer_address=...
  • +
  • POST /api/v2/keys/prekeys/upload
  • +
  • POST /api/v2/messages/send
  • +
  • GET /api/v2/conversations
  • +
  • POST /api/v2/conversations/dm
  • +
  • GET /api/v2/conversations/{conversation_id}/messages
  • +
  • GET /api/v2/federation/well-known
  • +
  • GET /api/v2/federation/users/{username}/devices
  • +
  • GET /api/v2/federation/users/{username}/prekeys
  • +
  • POST /api/v2/federation/messages/relay
  • +
  • POST /api/v2/federation/calls/webrtc-offer
  • +
  • POST /api/v2/federation/calls/webrtc-answer
  • +
  • POST /api/v2/federation/calls/webrtc-ice
  • +
  • POST /api/v2/federation/group-calls/webrtc-offer
  • +
  • POST /api/v2/federation/group-calls/webrtc-answer
  • +
  • POST /api/v2/federation/group-calls/webrtc-ice
  • +
  • WS /api/v2/ws
  • +
+
+ +
+
+

Quick Start

+

+ Enable v0.2 rollout flags, run the stack, and verify v2 federation and messaging paths. +

+
+
+
+

Prepare Environment

+ Copy-Item infra/example.env infra/.env + ./infra/randomize-env-secrets.ps1 +
+
+

Enable v0.2b Flags

+ BLACKWIRE_ENABLE_RATCHET_V2B1=true + BLACKWIRE_ENABLE_WEBRTC_V2B2=true +
+
+

Set Voice Mode

+ BLACKWIRE_ENABLE_LEGACY_CALL_AUDIO_WS=false + BLACKWIRE_WEBRTC_ICE_SERVERS_JSON=[{"urls":"stun:stun.l.google.com:19302"}] +
+
+

Start + Verify v2

+ docker compose -f infra/docker-compose.yml up --build + http://localhost:8000/api/v2/federation/well-known +
+
+
+ v0.2a security milestones are implemented; v0.2b ratchet/WebRTC paths are available with feature gates while /api/v1 remains compatible. +
+
+
+ +
+
+
Blackwire v0.2 release showcase site for GitHub Pages.
+
Updated | See README.md for full setup details.
+
+
+ + + + diff --git a/infra/example.env b/infra/example.env index e344142..45123e0 100644 --- a/infra/example.env +++ b/infra/example.env @@ -4,18 +4,51 @@ BLACKWIRE_AUTO_CREATE_TABLES=false BLACKWIRE_JWT_SECRET_KEY=replace-with-strong-secret-min-32-bytes BLACKWIRE_ACCESS_TOKEN_MINUTES=15 BLACKWIRE_REFRESH_TOKEN_DAYS=30 +BLACKWIRE_JWT_PRIVATE_KEY_PEM= +BLACKWIRE_JWT_PUBLIC_KEY_PEM= +BLACKWIRE_V2_BOOTSTRAP_TOKEN_SECONDS=120 +BLACKWIRE_V2_ACCESS_TOKEN_MINUTES=10 +BLACKWIRE_V2_REFRESH_TOKEN_DAYS=30 +BLACKWIRE_V2_BIND_REQUEST_SKEW_SECONDS=300 BLACKWIRE_MESSAGE_TTL_DAYS=7 +BLACKWIRE_ATTACHMENT_INLINE_MAX_BYTES=10485760 +BLACKWIRE_ATTACHMENT_HARD_CEILING_BYTES=33554432 +BLACKWIRE_MAX_CIPHERTEXT_BYTES=20971520 +BLACKWIRE_MAX_CLIENT_MESSAGE_BODY_BYTES=31457280 +BLACKWIRE_MAX_FEDERATION_BODY_BYTES=31457280 +BLACKWIRE_MESSAGE_BYTES_PER_MINUTE_PER_USER=52428800 +BLACKWIRE_FEDERATION_BYTES_PER_MINUTE_PER_PEER=104857600 +BLACKWIRE_PENDING_QUEUE_MAX_COPIES_PER_DEVICE=500 +BLACKWIRE_PENDING_QUEUE_MAX_BYTES_PER_DEVICE=67108864 +BLACKWIRE_ENABLE_RATCHET_V2B1=false +BLACKWIRE_RATCHET_REQUIRE_FOR_LOCAL=false +BLACKWIRE_RATCHET_REQUIRE_FOR_FEDERATION=false +BLACKWIRE_ENABLE_WEBRTC_V2B2=false +BLACKWIRE_WEBRTC_ICE_SERVERS_JSON= +BLACKWIRE_ENABLE_LEGACY_CALL_AUDIO_WS=true +BLACKWIRE_ENABLE_GROUP_DM_V2C=false +BLACKWIRE_ENABLE_GROUP_CALL_V2C=false +BLACKWIRE_GROUP_MAX_MEMBERS=32 +BLACKWIRE_GROUP_CALL_MAX_PARTICIPANTS=8 +BLACKWIRE_GROUP_INVITE_RATE_PER_MINUTE=120 +BLACKWIRE_GROUP_CREATE_RATE_PER_HOUR=30 +BLACKWIRE_GROUP_CALL_START_RATE_PER_MINUTE=30 +BLACKWIRE_GROUP_CALL_RING_TTL_SECONDS=45 BLACKWIRE_RATE_LIMIT_PER_MINUTE=120 BLACKWIRE_USE_REDIS_RATE_LIMIT=false BLACKWIRE_REDIS_URL= BLACKWIRE_TOR_ENABLED=true BLACKWIRE_TOR_SOCKS5_URL=socks5h://127.0.0.1:9050 BLACKWIRE_TOR_HS_HOSTNAME_FILE=/var/lib/tor/hidden_service/hostname +BLACKWIRE_TOR_HS_ED25519_SECRET_KEY_FILE=/var/lib/tor/hidden_service/hs_ed25519_secret_key BLACKWIRE_TOR_HS_HOSTNAME_WAIT_SECONDS=45 BLACKWIRE_FEDERATION_SERVER_ONION=local.invalid -# Required when BLACKWIRE_TOR_ENABLED=true. Must be base64-encoded 32-byte Ed25519 seed. -# Example generation (python): python -c "import os,base64;print(base64.b64encode(os.urandom(32)).decode())" -BLACKWIRE_FEDERATION_SIGNING_PRIVATE_KEY_B64=your_base64_private_key +BLACKWIRE_LOCAL_SERVER_ALIASES=localhost,127.0.0.1,localhost:8000,127.0.0.1:8000 +# Add LAN/WAN authorities used by local clients (for example: 192.168.1.10,203.0.113.5:8000) +# so alias routing treats them as local. +# Optional fallback for non-Tor deployments. Must be base64-encoded 32-byte Ed25519 seed. +# In Tor mode the key is derived from hidden-service key material. +BLACKWIRE_FEDERATION_SIGNING_PRIVATE_KEY_B64= BLACKWIRE_FEDERATION_REQUEST_SKEW_SECONDS=60 BLACKWIRE_FEDERATION_NONCE_TTL_SECONDS=300 BLACKWIRE_FEDERATION_OUTBOX_POLL_INTERVAL_SECONDS=5 diff --git a/server/app/api/conversations.py b/server/app/api/conversations.py index 30e4387..5bdd2b5 100644 --- a/server/app/api/conversations.py +++ b/server/app/api/conversations.py @@ -34,11 +34,13 @@ async def create_dm( current_user: User = Depends(get_current_user), ) -> ConversationOut: await rate_limiter.enforce(client_rate_limit_key(request, f"conversation-create:{current_user.id}")) + request_authority = request.headers.get("host", "").strip().lower() conversation = await conversation_service.create_dm( session, current_user, payload.peer_username, payload.peer_address, + request_authority=request_authority, ) peer_username = await conversation_service.peer_username_for_user(session, conversation, current_user.id) peer_address = await conversation_service.peer_address_for_user(session, conversation, current_user.id) diff --git a/server/app/api/users.py b/server/app/api/users.py index ed3b256..f79a935 100644 --- a/server/app/api/users.py +++ b/server/app/api/users.py @@ -50,7 +50,11 @@ async def resolve_user_device( current_user: User = Depends(get_current_user), ) -> UserDeviceLookup: await rate_limiter.enforce(client_rate_limit_key(request, f"user-device-resolve:{current_user.id}")) - result = await device_service.resolve_device_by_peer_address(session, peer_address) + result = await device_service.resolve_device_by_peer_address( + session, + peer_address, + request_authority=request.headers.get("host", "").strip().lower(), + ) if result is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Active device not found") return result diff --git a/server/app/api_v2/__init__.py b/server/app/api_v2/__init__.py new file mode 100644 index 0000000..3f37e1d --- /dev/null +++ b/server/app/api_v2/__init__.py @@ -0,0 +1,3 @@ +from app.api_v2 import auth, conversations, devices, federation, messages, presence, users + +__all__ = ["auth", "conversations", "devices", "federation", "messages", "presence", "users"] diff --git a/server/app/api_v2/auth.py b/server/app/api_v2/auth.py new file mode 100644 index 0000000..f693dd2 --- /dev/null +++ b/server/app/api_v2/auth.py @@ -0,0 +1,91 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import db_session, get_bootstrap_user_v2 +from app.models.user import User +from app.schemas.user import UserOut +from app.schemas.v2_auth import ( + BindDeviceRequestV2, + BootstrapAuthResponseV2, + DeviceAuthResponseV2, + LoginRequestV2, + LogoutRequestV2, + RefreshRequestV2, + RegisterRequestV2, +) +from app.services.auth_service import AuthServiceError +from app.services.auth_service_v2 import auth_service_v2 +from app.services.rate_limit import rate_limiter + +router = APIRouter(prefix="/api/v2/auth", tags=["auth-v2"]) + + +@router.post("/register", response_model=BootstrapAuthResponseV2, status_code=status.HTTP_201_CREATED) +async def register( + payload: RegisterRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), +) -> BootstrapAuthResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-auth-register")) + try: + user = await auth_service_v2.register(session, payload.username, payload.password) + tokens = await auth_service_v2.issue_bootstrap(user) + return BootstrapAuthResponseV2(user=UserOut.from_user(user), tokens=tokens) + except AuthServiceError as exc: + raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc + + +@router.post("/login", response_model=BootstrapAuthResponseV2) +async def login( + payload: LoginRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), +) -> BootstrapAuthResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-auth-login")) + try: + user = await auth_service_v2.authenticate(session, payload.username, payload.password) + tokens = await auth_service_v2.issue_bootstrap(user) + return BootstrapAuthResponseV2(user=UserOut.from_user(user), tokens=tokens) + except AuthServiceError as exc: + raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc + + +@router.post("/bind-device", response_model=DeviceAuthResponseV2) +async def bind_device( + payload: BindDeviceRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + bootstrap_user: User = Depends(get_bootstrap_user_v2), +) -> DeviceAuthResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-auth-bind-device")) + try: + tokens = await auth_service_v2.bind_device(session, bootstrap_user, payload) + return DeviceAuthResponseV2(user=UserOut.from_user(bootstrap_user), tokens=tokens) + except AuthServiceError as exc: + raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc + + +@router.post("/refresh", response_model=DeviceAuthResponseV2) +async def refresh( + payload: RefreshRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), +) -> DeviceAuthResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-auth-refresh")) + try: + user, tokens = await auth_service_v2.refresh(session, payload.refresh_token) + return DeviceAuthResponseV2(user=UserOut.from_user(user), tokens=tokens) + except AuthServiceError as exc: + raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc + + +@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT) +async def logout( + payload: LogoutRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), +) -> None: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-auth-logout")) + await auth_service_v2.logout(session, payload.refresh_token) + diff --git a/server/app/api_v2/conversations.py b/server/app/api_v2/conversations.py new file mode 100644 index 0000000..949588f --- /dev/null +++ b/server/app/api_v2/conversations.py @@ -0,0 +1,274 @@ +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_current_device_context_v2 +from app.schemas.v2_conversation import ( + ConversationMemberOutV2, + ConversationOutV2, + ConversationRecipientsOutV2, + CreateDMConversationRequestV2, + CreateGroupConversationRequestV2, + GroupInviteRequestV2, + GroupLeaveRequestV2, + GroupRenameRequestV2, +) +from app.schemas.v2_message import MessageDeviceCopyOutV2, MessageEventOutV2 +from app.services.conversation_service import conversation_service +from app.services.group_conversation_service import group_conversation_service +from app.services.message_service_v2 import message_service_v2 +from app.services.rate_limit import rate_limiter + +router = APIRouter(prefix="/api/v2/conversations", tags=["conversations-v2"]) + + +async def _serialize_conversation( + session: AsyncSession, + conversation, + context: AuthenticatedDeviceContextV2, +) -> ConversationOutV2: + if conversation.conversation_type == "group": + return await group_conversation_service.to_conversation_out( + session, + conversation=conversation, + user=context.user, + ) + peer_username = await conversation_service.peer_username_for_user(session, conversation, context.user.id) + peer_address = await conversation_service.peer_address_for_user(session, conversation, context.user.id) + peer_server_onion = await conversation_service.peer_server_onion_for_user(session, conversation, context.user.id) + return ConversationOutV2( + id=conversation.id, + kind=conversation.kind, + user_a_id=conversation.user_a_id or "", + user_b_id=conversation.user_b_id or "", + local_user_id=conversation.local_user_id or "", + created_at=conversation.created_at, + peer_username=peer_username, + peer_server_onion=peer_server_onion, + peer_address=peer_address, + conversation_type="direct", + group_uid=None, + group_name="", + member_count=0, + membership_state="none", + can_manage_members=False, + origin_server_onion="", + owner_address="", + ) + + +@router.post("/dm", response_model=ConversationOutV2) +async def create_dm( + payload: CreateDMConversationRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-conversation-create:{context.user.id}")) + request_authority = request.headers.get("host", "").strip().lower() + conversation = await conversation_service.create_dm( + session, + context.user, + payload.peer_username, + payload.peer_address, + request_authority=request_authority, + ) + return await _serialize_conversation(session, conversation, context) + + +@router.post("/group", response_model=ConversationOutV2) +async def create_group( + payload: CreateGroupConversationRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-create:{context.user.id}")) + conversation = await group_conversation_service.create_group( + session, + context.user, + name=payload.name, + member_addresses=payload.member_addresses, + request_authority=request.headers.get("host", "").strip().lower(), + ) + return await _serialize_conversation(session, conversation, context) + + +@router.get("", response_model=list[ConversationOutV2]) +async def list_conversations( + request: Request, + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> list[ConversationOutV2]: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-conversation-list:{context.user.id}")) + conversations = await conversation_service.list_for_user(session, context.user, limit=limit, offset=offset) + return [await _serialize_conversation(session, item, context) for item in conversations] + + +@router.get("/{conversation_id}/members", response_model=list[ConversationMemberOutV2]) +async def list_members( + conversation_id: str, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> list[ConversationMemberOutV2]: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-members:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + members = await group_conversation_service.list_members(session, conversation, context.user.id) + return [group_conversation_service.member_out(row) for row in members] + + +@router.post("/{conversation_id}/members/invite", response_model=list[ConversationMemberOutV2]) +async def invite_members( + conversation_id: str, + payload: GroupInviteRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> list[ConversationMemberOutV2]: + await rate_limiter.enforce( + client_rate_limit_key(request, f"v2-group-invite:{context.user.id}"), + limit=group_conversation_service.settings.group_invite_rate_per_minute, + ) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + rows = await group_conversation_service.invite_members( + session, + conversation=conversation, + owner=context.user, + member_addresses=payload.member_addresses, + request_authority=request.headers.get("host", "").strip().lower(), + ) + return [group_conversation_service.member_out(row) for row in rows] + + +@router.post("/{conversation_id}/members/{member_address:path}/remove", response_model=ConversationMemberOutV2) +async def remove_member( + conversation_id: str, + member_address: str, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationMemberOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-remove:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + row = await group_conversation_service.remove_member( + session, + conversation=conversation, + owner=context.user, + member_address=member_address, + ) + return group_conversation_service.member_out(row) + + +@router.post("/{conversation_id}/invites/accept", response_model=ConversationMemberOutV2) +async def accept_invite( + conversation_id: str, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationMemberOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-accept:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + row = await group_conversation_service.accept_invite(session, conversation=conversation, user=context.user) + return group_conversation_service.member_out(row) + + +@router.post("/{conversation_id}/leave", response_model=ConversationMemberOutV2) +async def leave_group( + conversation_id: str, + request: Request, + payload: GroupLeaveRequestV2 | None = None, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationMemberOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-leave:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + row = await group_conversation_service.leave( + session, + conversation=conversation, + user=context.user, + reason=(payload.reason if payload is not None else None), + ) + return group_conversation_service.member_out(row) + + +@router.post("/{conversation_id}/rename", response_model=ConversationOutV2) +async def rename_group( + conversation_id: str, + payload: GroupRenameRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-rename:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=404, detail="Conversation not found") + renamed = await group_conversation_service.rename(session, conversation=conversation, owner=context.user, name=payload.name) + return await _serialize_conversation(session, renamed, context) + + +@router.get("/{conversation_id}/recipients", response_model=ConversationRecipientsOutV2) +async def get_recipients( + conversation_id: str, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ConversationRecipientsOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-group-recipients:{context.user.id}")) + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None: + raise HTTPException(status_code=404, detail="Conversation not found") + if conversation.conversation_type != "group": + raise HTTPException(status_code=400, detail="Recipients endpoint only applies to group conversations") + return await group_conversation_service.members_for_recipients( + session, + conversation=conversation, + user=context.user, + current_device_uid=context.device.id, + ) + + +@router.get("/{conversation_id}/messages", response_model=list[MessageDeviceCopyOutV2]) +async def list_messages( + conversation_id: str, + request: Request, + limit: int = Query(default=50, ge=1, le=200), + offset: int = Query(default=0, ge=0), + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> list[MessageDeviceCopyOutV2]: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-message-list:{context.user.id}")) + rows = await message_service_v2.list_messages_for_device( + session, + context.user, + context.device.id, + conversation_id, + limit=limit, + offset=offset, + ) + output: list[MessageDeviceCopyOutV2] = [] + for copy, event in rows: + output.append( + MessageDeviceCopyOutV2( + copy_id=copy.id, + message=MessageEventOutV2.model_validate(event), + recipient_device_uid=copy.recipient_device_uid, + envelope_json=copy.envelope_json, + status=copy.status, + created_at=copy.created_at, + ) + ) + return output diff --git a/server/app/api_v2/devices.py b/server/app/api_v2/devices.py new file mode 100644 index 0000000..9ee3198 --- /dev/null +++ b/server/app/api_v2/devices.py @@ -0,0 +1,58 @@ +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_bootstrap_user_v2, get_current_device_context_v2 +from app.models.user import User +from app.schemas.user import UserOut +from app.schemas.v2_auth import DeviceAuthResponseV2 +from app.schemas.v2_device import DeviceOutV2, DeviceRegisterRequestV2 +from app.services.auth_service import AuthServiceError +from app.services.auth_service_v2 import auth_service_v2 +from app.services.device_service_v2 import device_service_v2 +from app.services.rate_limit import rate_limiter + +router = APIRouter(prefix="/api/v2/devices", tags=["devices-v2"]) + + +@router.post("/register", response_model=DeviceAuthResponseV2) +async def register_device( + payload: DeviceRegisterRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + bootstrap_user: User = Depends(get_bootstrap_user_v2), +) -> DeviceAuthResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-device-register")) + try: + _, tokens = await auth_service_v2.register_device_from_bootstrap( + session, + bootstrap_user, + payload.label, + payload.pub_sign_key, + payload.pub_dh_key, + ) + return DeviceAuthResponseV2(user=UserOut.from_user(bootstrap_user), tokens=tokens) + except AuthServiceError as exc: + raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc + + +@router.get("", response_model=list[DeviceOutV2]) +async def list_devices( + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> list[DeviceOutV2]: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-device-list:{context.user.id}")) + return await device_service_v2.list_for_user(session, context.user.id) + + +@router.post("/{device_uid}/revoke", response_model=DeviceOutV2) +async def revoke_device( + device_uid: str, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> DeviceOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-device-revoke:{context.user.id}")) + return await device_service_v2.revoke_device(session, context.user, device_uid) + diff --git a/server/app/api_v2/federation.py b/server/app/api_v2/federation.py new file mode 100644 index 0000000..0004a8a --- /dev/null +++ b/server/app/api_v2/federation.py @@ -0,0 +1,331 @@ +from fastapi import APIRouter, Depends, HTTPException, Path, Request, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.config import get_settings +from app.dependencies import db_session +from app.models.user import User +from app.schemas.v2_federation import ( + FederationCallWebRtcAnswerRequestV2, + FederationCallWebRtcIceRequestV2, + FederationCallWebRtcOfferRequestV2, + FederationGroupCallEndRequestV2, + FederationGroupCallJoinRequestV2, + FederationGroupCallLeaveRequestV2, + FederationGroupCallOfferRequestV2, + FederationGroupCallWebRtcAnswerRequestV2, + FederationGroupCallWebRtcIceRequestV2, + FederationGroupCallWebRtcOfferRequestV2, + FederationGroupEventRequestV2, + FederationGroupInviteAcceptRequestV2, + FederationMessageRelayRequestV2, + FederationGroupSnapshotOutV2, + FederationWellKnownOutV2, +) +from app.schemas.v2_device import DeviceOutV2, UserDeviceLookupV2 +from app.services.device_service_v2 import device_service_v2 +from app.services.call_service import CallProtocolError, call_service +from app.services.federation_security import federation_security_service +from app.services.group_call_service import group_call_service +from app.services.group_conversation_service import group_conversation_service +from app.services.message_service_v2 import message_service_v2 +from app.services.metrics import metrics +from app.services.prekey_service_v2 import prekey_service_v2 +from app.services.rate_limit import rate_limiter +from app.services.server_identity import ( + get_federation_signing_public_key_b64, + get_server_onion, + server_address_for_username, +) + +router = APIRouter(prefix="/api/v2/federation", tags=["federation-v2"]) + + +def _raise_for_call_error(exc: CallProtocolError) -> None: + conflict_codes = {"peer_busy", "peer_offline", "invalid_state", "call_not_ringing"} + not_found_codes = {"call_not_found", "conversation_not_found", "user_not_found"} + forbidden_codes = {"forbidden", "invalid_target"} + + if exc.code in conflict_codes: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=exc.detail) from exc + if exc.code in not_found_codes: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=exc.detail) from exc + if exc.code in forbidden_codes: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=exc.detail) from exc + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=exc.detail) from exc + + +async def _verify_federation_write_auth( + request: Request, + session: AsyncSession, +) -> bytes: + settings = get_settings() + await rate_limiter.enforce(client_rate_limit_key(request, "v2-federation-write"), limit=240) + raw_body = await request.body() + if len(raw_body) > settings.max_federation_body_bytes: + await metrics.inc("attachments.send.rejected_too_large") + raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Federation payload too large") + await federation_security_service.verify_incoming(session, request, raw_body) + sender = request.headers.get("x-bw-sender", "unknown").strip().lower() or "unknown" + try: + await rate_limiter.enforce_weighted( + f"v2-federation-bytes:{sender}", + units=len(raw_body), + limit=settings.federation_bytes_per_minute_per_peer, + ) + except HTTPException as exc: + if exc.status_code == status.HTTP_429_TOO_MANY_REQUESTS: + await metrics.inc("attachments.send.rejected_rate_limited") + raise + return raw_body + + +@router.get("/well-known", response_model=FederationWellKnownOutV2) +async def well_known() -> FederationWellKnownOutV2: + settings = get_settings() + supported_call_modes: list[str] = [] + if settings.enable_webrtc_v2b2: + supported_call_modes.append("webrtc_v0_2b2") + if settings.enable_legacy_call_audio_ws: + supported_call_modes.append("ws_pcm_v0_2a") + return FederationWellKnownOutV2( + server_onion=get_server_onion(), + federation_version="2", + signing_public_key=get_federation_signing_public_key_b64(), + identity_binding_mode="tor_v3_same_ed25519", + attachment_inline_max_bytes=settings.effective_attachment_inline_max_bytes(), + max_ciphertext_bytes=settings.effective_max_ciphertext_bytes(), + attachment_hard_ceiling_bytes=settings.attachment_hard_ceiling_bytes, + supported_message_modes=device_service_v2.supported_message_modes(), + supported_call_modes=supported_call_modes, + ) + + +@router.get("/users/{username}/devices", response_model=UserDeviceLookupV2) +async def get_local_user_devices( + request: Request, + username: str = Path(pattern=r"^[A-Za-z0-9_]{3,64}$"), + session: AsyncSession = Depends(db_session), +) -> UserDeviceLookupV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-federation-device-lookup"), limit=120) + normalized = username.strip().lower() + user_stmt = select(User).where(User.username == normalized, User.disabled_at.is_(None)) + user = (await session.execute(user_stmt)).scalar_one_or_none() + if user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + + devices = await device_service_v2.list_active_for_user(session, user.id) + return UserDeviceLookupV2( + username=user.username, + peer_address=server_address_for_username(user.username), + devices=[ + DeviceOutV2( + device_uid=device.id, + user_id=device.user_id, + label=device.label, + pub_sign_key=device.ik_ed25519_pub, + pub_dh_key=device.enc_x25519_pub, + status=device.status, + supported_message_modes=device_service_v2.supported_message_modes(), + created_at=device.created_at, + last_seen_at=device.last_seen_at, + revoked_at=device.revoked_at, + ) + for device in devices + ], + attachment_inline_max_bytes=get_settings().effective_attachment_inline_max_bytes(), + max_ciphertext_bytes=get_settings().effective_max_ciphertext_bytes(), + attachment_policy_source="local", + ) + + +@router.get("/users/{username}/prekeys") +async def get_local_user_prekeys( + request: Request, + username: str = Path(pattern=r"^[A-Za-z0-9_]{3,64}$"), + session: AsyncSession = Depends(db_session), +) -> dict: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-federation-prekey-lookup"), limit=120) + requested_by = request.headers.get("x-bw-sender", "federation@unknown").strip().lower() or "federation@unknown" + result = await prekey_service_v2.resolve_local_user_prekeys_for_federation(session, username, requested_by) + if result is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + return result.model_dump() + + +@router.post("/messages/relay") +async def relay_message( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationMessageRelayRequestV2.model_validate_json(raw_body) + try: + await message_service_v2.relay_message_from_federation(session, payload) + except HTTPException as exc: + if exc.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: + await metrics.inc("attachments.send.rejected_too_large") + raise + return {"status": "ok"} + + +@router.post("/groups/events") +async def relay_group_event( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupEventRequestV2.model_validate_json(raw_body) + sender_onion = request.headers.get("x-bw-sender", "").strip().lower() or None + await group_conversation_service.apply_federation_event( + session, + payload=payload, + sender_onion=sender_onion, + ) + return {"status": "ok"} + + +@router.get("/groups/{group_uid}/snapshot", response_model=FederationGroupSnapshotOutV2) +async def group_snapshot( + request: Request, + group_uid: str, + session: AsyncSession = Depends(db_session), +) -> FederationGroupSnapshotOutV2: + await rate_limiter.enforce(client_rate_limit_key(request, "v2-federation-group-snapshot"), limit=120) + await federation_security_service.verify_incoming(session, request, b"") + return await group_conversation_service.snapshot_for_group(session, group_uid) + + +@router.post("/groups/invites/accept") +async def group_invite_accept( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupInviteAcceptRequestV2.model_validate_json(raw_body) + await group_conversation_service.accept_remote_invite_to_origin( + session, + group_uid=payload.group_uid, + actor_address=payload.actor_address, + ) + return {"status": "ok"} + + +@router.post("/group-calls/offer") +async def relay_group_call_offer( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallOfferRequestV2.model_validate_json(raw_body) + await group_call_service.relay_offer(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/join") +async def relay_group_call_join( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallJoinRequestV2.model_validate_json(raw_body) + await group_call_service.relay_join(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/leave") +async def relay_group_call_leave( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallLeaveRequestV2.model_validate_json(raw_body) + await group_call_service.relay_leave(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/end") +async def relay_group_call_end( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallEndRequestV2.model_validate_json(raw_body) + await group_call_service.relay_end(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/webrtc-offer") +async def relay_group_call_webrtc_offer( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallWebRtcOfferRequestV2.model_validate_json(raw_body) + await group_call_service.relay_webrtc_offer(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/webrtc-answer") +async def relay_group_call_webrtc_answer( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallWebRtcAnswerRequestV2.model_validate_json(raw_body) + await group_call_service.relay_webrtc_answer(session, payload) + return {"status": "ok"} + + +@router.post("/group-calls/webrtc-ice") +async def relay_group_call_webrtc_ice( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationGroupCallWebRtcIceRequestV2.model_validate_json(raw_body) + await group_call_service.relay_webrtc_ice(session, payload) + return {"status": "ok"} + + +@router.post("/calls/webrtc-offer") +async def relay_webrtc_offer( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationCallWebRtcOfferRequestV2.model_validate_json(raw_body) + try: + await call_service.relay_webrtc_offer(payload) + except CallProtocolError as exc: + _raise_for_call_error(exc) + return {"status": "ok"} + + +@router.post("/calls/webrtc-answer") +async def relay_webrtc_answer( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationCallWebRtcAnswerRequestV2.model_validate_json(raw_body) + try: + await call_service.relay_webrtc_answer(payload) + except CallProtocolError as exc: + _raise_for_call_error(exc) + return {"status": "ok"} + + +@router.post("/calls/webrtc-ice") +async def relay_webrtc_ice( + request: Request, + session: AsyncSession = Depends(db_session), +) -> dict[str, str]: + raw_body = await _verify_federation_write_auth(request, session) + payload = FederationCallWebRtcIceRequestV2.model_validate_json(raw_body) + try: + await call_service.relay_webrtc_ice(payload) + except CallProtocolError as exc: + _raise_for_call_error(exc) + return {"status": "ok"} diff --git a/server/app/api_v2/keys.py b/server/app/api_v2/keys.py new file mode 100644 index 0000000..781385e --- /dev/null +++ b/server/app/api_v2/keys.py @@ -0,0 +1,21 @@ +from fastapi import APIRouter, Depends, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_current_device_context_v2 +from app.schemas.v2_prekey import PrekeyUploadRequestV2, PrekeyUploadResponseV2 +from app.services.prekey_service_v2 import prekey_service_v2 +from app.services.rate_limit import rate_limiter + +router = APIRouter(prefix="/api/v2/keys", tags=["keys-v2"]) + + +@router.post("/prekeys/upload", response_model=PrekeyUploadResponseV2) +async def upload_prekeys( + payload: PrekeyUploadRequestV2, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> PrekeyUploadResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-prekeys-upload:{context.user.id}")) + return await prekey_service_v2.upload_prekeys(session, context.user, context.device, payload) diff --git a/server/app/api_v2/messages.py b/server/app/api_v2/messages.py new file mode 100644 index 0000000..3d58e3d --- /dev/null +++ b/server/app/api_v2/messages.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import ValidationError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.config import get_settings +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_current_device_context_v2 +from app.schemas.v2_message import MessageEventOutV2, MessageSendRequestV2, MessageSendResponseV2 +from app.services.message_service_v2 import message_service_v2 +from app.services.metrics import metrics +from app.services.rate_limit import rate_limiter + +router = APIRouter(prefix="/api/v2/messages", tags=["messages-v2"]) + + +@router.post("/send", response_model=MessageSendResponseV2) +async def send_message( + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> MessageSendResponseV2: + settings = get_settings() + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-message-send:{context.user.id}")) + raw_body = await request.body() + if len(raw_body) > settings.max_client_message_body_bytes: + await metrics.inc("attachments.send.rejected_too_large") + raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Message payload too large") + try: + await rate_limiter.enforce_weighted( + f"v2-message-bytes:{context.user.id}", + units=len(raw_body), + limit=settings.message_bytes_per_minute_per_user, + ) + except HTTPException as exc: + if exc.status_code == status.HTTP_429_TOO_MANY_REQUESTS: + await metrics.inc("attachments.send.rejected_rate_limited") + raise + try: + payload = MessageSendRequestV2.model_validate_json(raw_body) + except ValidationError as exc: + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=exc.errors()) from exc + try: + message, duplicate = await message_service_v2.send_message( + session, + context.user, + context.device, + payload, + request_authority=request.headers.get("host", "").strip().lower(), + ) + except HTTPException as exc: + if exc.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: + await metrics.inc("attachments.send.rejected_too_large") + raise + return MessageSendResponseV2( + duplicate=duplicate, + message=MessageEventOutV2.model_validate(message), + ) diff --git a/server/app/api_v2/presence.py b/server/app/api_v2/presence.py new file mode 100644 index 0000000..61c6505 --- /dev/null +++ b/server/app/api_v2/presence.py @@ -0,0 +1,43 @@ +from fastapi import APIRouter, Depends, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_current_device_context_v2 +from app.schemas.v2_presence import ( + PresencePeerOut, + PresenceResolveRequest, + PresenceResolveResponse, + PresenceSetRequest, + PresenceSetResponse, +) +from app.services.presence_service import presence_service +from app.services.rate_limit import rate_limiter + +router = APIRouter(tags=["presence-v2"]) + + +@router.post("/api/v2/presence/set", response_model=PresenceSetResponse) +async def set_presence_status( + payload: PresenceSetRequest, + request: Request, + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> PresenceSetResponse: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-presence-set:{context.user.id}")) + status_value = await presence_service.set_status(context.user.id, payload.status) + return PresenceSetResponse(status=status_value) + + +@router.post("/api/v2/presence/resolve", response_model=PresenceResolveResponse) +async def resolve_presence_status( + payload: PresenceResolveRequest, + request: Request, + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> PresenceResolveResponse: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-presence-resolve:{context.user.id}")) + peers = await presence_service.resolve_for_peer_addresses(session, payload.peer_addresses) + self_status = await presence_service.effective_status(context.user.id) + return PresenceResolveResponse( + self_status=self_status, + peers=[PresencePeerOut(peer_address=peer_address, status=status) for peer_address, status in peers], + ) diff --git a/server/app/api_v2/users.py b/server/app/api_v2/users.py new file mode 100644 index 0000000..bc8b248 --- /dev/null +++ b/server/app/api_v2/users.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.utils import client_rate_limit_key +from app.dependencies import AuthenticatedDeviceContextV2, db_session, get_current_device_context_v2 +from app.schemas.user import UserOut +from app.schemas.v2_device import UserDeviceLookupV2 +from app.schemas.v2_prekey import ResolvePrekeysResponseV2 +from app.services.device_service_v2 import device_service_v2 +from app.services.prekey_service_v2 import prekey_service_v2 +from app.services.rate_limit import rate_limiter + +router = APIRouter(tags=["users-v2"]) + + +@router.get("/api/v2/me", response_model=UserOut) +async def me( + request: Request, + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> UserOut: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-me:{context.user.id}")) + return UserOut.from_user(context.user) + + +@router.get("/api/v2/users/resolve-devices", response_model=UserDeviceLookupV2) +async def resolve_user_devices( + request: Request, + peer_address: str = Query(min_length=3, max_length=320), + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> UserDeviceLookupV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-user-devices:{context.user.id}")) + result = await device_service_v2.resolve_devices_by_peer_address( + session, + peer_address, + request_authority=request.headers.get("host", "").strip().lower(), + ) + if result is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Active devices not found") + return result + + +@router.get("/api/v2/users/resolve-prekeys", response_model=ResolvePrekeysResponseV2) +async def resolve_user_prekeys( + request: Request, + peer_address: str = Query(min_length=3, max_length=320), + session: AsyncSession = Depends(db_session), + context: AuthenticatedDeviceContextV2 = Depends(get_current_device_context_v2), +) -> ResolvePrekeysResponseV2: + await rate_limiter.enforce(client_rate_limit_key(request, f"v2-user-prekeys:{context.user.id}")) + return await prekey_service_v2.resolve_prekeys( + session, + context.user, + peer_address, + request_authority=request.headers.get("host", "").strip().lower(), + ) diff --git a/server/app/api_v2/ws.py b/server/app/api_v2/ws.py new file mode 100644 index 0000000..a6900be --- /dev/null +++ b/server/app/api_v2/ws.py @@ -0,0 +1,310 @@ +import logging + +from fastapi import HTTPException, WebSocket, WebSocketDisconnect +from pydantic import ValidationError +from sqlalchemy import select + +from app.db import get_session_factory +from app.models.device import Device +from app.models.user import User +from app.schemas.call import ( + CallAcceptRequest, + CallAudioRequest, + CallEndRequest, + CallOfferRequest, + CallRejectRequest, + CallWebRtcAnswerRequest, + CallWebRtcIceRequest, + CallWebRtcOfferRequest, +) +from app.security.tokens_v2 import TokenErrorV2, decode_token +from app.services.call_service import CallProtocolError, call_service +from app.services.group_call_service import group_call_service +from app.services.message_service_v2 import message_service_v2 +from app.services.conversation_service import conversation_service +from app.ws.manager import connection_manager + +logger = logging.getLogger("blackwire.app.v2.ws") +_BEARER_PREFIX = "bearer " + + +def _extract_bearer_token(websocket: WebSocket) -> str | None: + auth_header = websocket.headers.get("authorization", "").strip() + if auth_header.lower().startswith(_BEARER_PREFIX): + token = auth_header[len(_BEARER_PREFIX) :].strip() + if token: + return token + query_token = websocket.query_params.get("access_token", "").strip() + if query_token: + return query_token + return None + + +async def websocket_endpoint_v2(websocket: WebSocket) -> None: + token = _extract_bearer_token(websocket) + if not token: + await websocket.close(code=1008, reason="Missing bearer auth token") + return + + try: + payload = decode_token(token, expected_type="access") + except TokenErrorV2: + await websocket.close(code=1008, reason="Invalid access token") + return + + user_id = str(payload.get("sub") or "") + device_uid = str(payload.get("did") or "") + if not user_id or not device_uid: + await websocket.close(code=1008, reason="Invalid access token") + return + + session_factory = get_session_factory() + async with session_factory() as session: + user = (await session.execute(select(User).where(User.id == user_id, User.disabled_at.is_(None)))).scalar_one_or_none() + device = ( + await session.execute( + select(Device).where( + Device.id == device_uid, + Device.user_id == user_id, + Device.status == "active", + Device.revoked_at.is_(None), + ) + ) + ).scalar_one_or_none() + if user is None or device is None: + await websocket.close(code=1008, reason="Device revoked") + return + + await connection_manager.connect(user_id=user_id, websocket=websocket, device_uid=device_uid) + + async def send_call_error(code: str, detail: str) -> None: + await websocket.send_json({"type": "call.error", "code": code, "detail": detail}) + + async with session_factory() as initial_session: + await message_service_v2.drain_pending_for_websocket(initial_session, user_id, device_uid, websocket) + if group_call_service.settings.enable_group_call_v2c: + await group_call_service.replay_pending_for_user(initial_session, user_id) + + try: + while True: + incoming = await websocket.receive_json() + msg_type = incoming.get("type") + + if msg_type == "message.ack": + copy_id = incoming.get("copy_id") or incoming.get("message_id") + if not copy_id: + await websocket.send_json( + {"type": "error", "code": "invalid_ack", "detail": "copy_id required"} + ) + continue + async with session_factory() as ack_session: + await message_service_v2.acknowledge( + ack_session, + user_id=user_id, + device_uid=device_uid, + copy_id=copy_id, + ) + continue + + if msg_type == "call.offer": + try: + payload_call = CallOfferRequest.model_validate(incoming) + async with session_factory() as call_session: + conversation = await conversation_service.get_by_id(call_session, payload_call.conversation_id) + if conversation is not None and conversation.conversation_type == "group": + await group_call_service.offer(call_session, user_id, payload_call.conversation_id) + else: + await call_service.offer(call_session, user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_offer", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_offer", str(exc.detail)) + continue + + if msg_type == "call.accept": + try: + payload_call = CallAcceptRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + await group_call_service.join(call_session, user_id, payload_call.call_id) + else: + await call_service.accept(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_accept", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_accept", str(exc.detail)) + continue + + if msg_type == "call.reject": + try: + payload_call = CallRejectRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + await group_call_service.reject(call_session, user_id, payload_call.call_id) + else: + await call_service.reject(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_reject", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_reject", str(exc.detail)) + continue + + if msg_type == "call.end": + try: + payload_call = CallEndRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + await group_call_service.leave( + call_session, + user_id, + payload_call.call_id, + reason=(payload_call.reason or "left"), + ) + else: + await call_service.end(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_end", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_end", str(exc.detail)) + continue + + if msg_type == "call.audio": + if not call_service.settings.enable_legacy_call_audio_ws: + await send_call_error("audio_deprecated", "WS audio transport is disabled; use WebRTC") + continue + try: + payload_call = CallAudioRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + await group_call_service.audio(call_session, user_id, payload_call) + else: + await call_service.audio(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_audio", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_audio", str(exc.detail)) + continue + + if msg_type == "call.webrtc.offer": + try: + payload_call = CallWebRtcOfferRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + target = payload_call.target_user_address.strip().lower() + if not target: + raise HTTPException(status_code=400, detail="target_user_address is required for group signaling") + await group_call_service.route_webrtc_signal( + call_session, + sender_user_id=user_id, + call_id=payload_call.call_id, + event_type="call.webrtc.offer", + body={ + "sdp": payload_call.sdp, + "call_schema_version": payload_call.call_schema_version, + "call_mode": payload_call.call_mode, + "max_participants": payload_call.max_participants, + }, + target_user_address=target, + ) + else: + await call_service.webrtc_offer(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_webrtc_offer", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_webrtc_offer", str(exc.detail)) + continue + + if msg_type == "call.webrtc.answer": + try: + payload_call = CallWebRtcAnswerRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + target = payload_call.target_user_address.strip().lower() + if not target: + raise HTTPException(status_code=400, detail="target_user_address is required for group signaling") + await group_call_service.route_webrtc_signal( + call_session, + sender_user_id=user_id, + call_id=payload_call.call_id, + event_type="call.webrtc.answer", + body={ + "sdp": payload_call.sdp, + "call_schema_version": payload_call.call_schema_version, + "call_mode": payload_call.call_mode, + "max_participants": payload_call.max_participants, + }, + target_user_address=target, + ) + else: + await call_service.webrtc_answer(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_webrtc_answer", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_webrtc_answer", str(exc.detail)) + continue + + if msg_type == "call.webrtc.ice": + try: + payload_call = CallWebRtcIceRequest.model_validate(incoming) + async with session_factory() as call_session: + if await group_call_service.is_group_call(call_session, payload_call.call_id): + target = payload_call.target_user_address.strip().lower() + if not target: + raise HTTPException(status_code=400, detail="target_user_address is required for group signaling") + await group_call_service.route_webrtc_signal( + call_session, + sender_user_id=user_id, + call_id=payload_call.call_id, + event_type="call.webrtc.ice", + body={ + "candidate": payload_call.candidate, + "sdp_mid": payload_call.sdp_mid, + "sdp_mline_index": payload_call.sdp_mline_index, + "call_schema_version": payload_call.call_schema_version, + "call_mode": payload_call.call_mode, + "max_participants": payload_call.max_participants, + }, + target_user_address=target, + ) + else: + await call_service.webrtc_ice(user_id, payload_call) + except ValidationError as exc: + await send_call_error("invalid_call_webrtc_ice", str(exc)) + except CallProtocolError as exc: + await send_call_error(exc.code, exc.detail) + except HTTPException as exc: + await send_call_error("invalid_call_webrtc_ice", str(exc.detail)) + continue + + await websocket.send_json( + { + "type": "error", + "code": "unsupported_event", + "detail": ( + "Supported client events: " + "message.ack, call.offer, call.accept, call.reject, call.end, " + "call.audio, call.webrtc.offer, call.webrtc.answer, call.webrtc.ice" + ), + } + ) + except WebSocketDisconnect: + logger.info("websocket_disconnect", extra={"user_id": user_id, "device_uid": device_uid}) + finally: + async with session_factory() as disconnect_session: + await group_call_service.handle_disconnect(disconnect_session, user_id) + await call_service.handle_disconnect(user_id) + await connection_manager.disconnect(user_id=user_id, websocket=websocket, device_uid=device_uid) diff --git a/server/app/config.py b/server/app/config.py index 31ad531..979f072 100644 --- a/server/app/config.py +++ b/server/app/config.py @@ -2,6 +2,7 @@ from typing import Literal from pydantic import Field +from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -19,22 +20,52 @@ class Settings(BaseSettings): jwt_algorithm: str = "HS256" access_token_minutes: int = 15 refresh_token_days: int = 30 + jwt_private_key_pem: str = "" + jwt_public_key_pem: str = "" + v2_bootstrap_token_seconds: int = 120 + v2_access_token_minutes: int = 10 + v2_refresh_token_days: int = 30 + v2_bind_request_skew_seconds: int = 300 message_ttl_days: int = 7 queue_cleanup_interval_seconds: int = 60 - max_ciphertext_bytes: int = 65536 + max_ciphertext_bytes: int = 20971520 max_aad_bytes: int = 4096 + attachment_inline_max_bytes: int = 10485760 + attachment_hard_ceiling_bytes: int = 33554432 + max_client_message_body_bytes: int = 31457280 + max_federation_body_bytes: int = 31457280 + message_bytes_per_minute_per_user: int = 52428800 + federation_bytes_per_minute_per_peer: int = 104857600 + pending_queue_max_copies_per_device: int = 500 + pending_queue_max_bytes_per_device: int = 67108864 + enable_ratchet_v2b1: bool = False + ratchet_require_for_local: bool = False + ratchet_require_for_federation: bool = False + enable_webrtc_v2b2: bool = False + webrtc_ice_servers_json: str = "" + enable_legacy_call_audio_ws: bool = True voice_call_ring_timeout_seconds: int = 30 voice_audio_max_chunk_bytes: int = 4096 voice_audio_min_interval_ms: int = 8 + enable_group_dm_v2c: bool = False + enable_group_call_v2c: bool = False + group_max_members: int = 32 + group_call_max_participants: int = 8 + group_invite_rate_per_minute: int = 120 + group_create_rate_per_hour: int = 30 + group_call_start_rate_per_minute: int = 30 + group_call_ring_ttl_seconds: int = 45 tor_enabled: bool = False tor_socks5_url: str = "socks5h://127.0.0.1:9050" tor_hs_hostname_file: str = "/var/lib/tor/hidden_service/hostname" + tor_hs_ed25519_secret_key_file: str = "/var/lib/tor/hidden_service/hs_ed25519_secret_key" tor_hs_hostname_wait_seconds: int = 45 federation_server_onion: str = "local.invalid" + local_server_aliases: str = "" federation_signing_private_key_b64: str = "" federation_request_skew_seconds: int = 60 federation_nonce_ttl_seconds: int = 300 @@ -46,6 +77,56 @@ class Settings(BaseSettings): allow_origins: list[str] = Field(default_factory=lambda: []) + def effective_attachment_inline_max_bytes(self) -> int: + return min(self.attachment_inline_max_bytes, self.attachment_hard_ceiling_bytes) + + def effective_max_ciphertext_bytes(self) -> int: + return min(self.max_ciphertext_bytes, self.attachment_hard_ceiling_bytes) + + @model_validator(mode="after") + def validate_attachment_limits(self) -> "Settings": + if self.attachment_hard_ceiling_bytes <= 0: + raise ValueError("attachment_hard_ceiling_bytes must be greater than 0") + if self.attachment_inline_max_bytes <= 0: + raise ValueError("attachment_inline_max_bytes must be greater than 0") + if self.max_ciphertext_bytes <= 0: + raise ValueError("max_ciphertext_bytes must be greater than 0") + if self.max_client_message_body_bytes <= 0: + raise ValueError("max_client_message_body_bytes must be greater than 0") + if self.max_federation_body_bytes <= 0: + raise ValueError("max_federation_body_bytes must be greater than 0") + if self.message_bytes_per_minute_per_user <= 0: + raise ValueError("message_bytes_per_minute_per_user must be greater than 0") + if self.federation_bytes_per_minute_per_peer <= 0: + raise ValueError("federation_bytes_per_minute_per_peer must be greater than 0") + if self.pending_queue_max_copies_per_device <= 0: + raise ValueError("pending_queue_max_copies_per_device must be greater than 0") + if self.pending_queue_max_bytes_per_device <= 0: + raise ValueError("pending_queue_max_bytes_per_device must be greater than 0") + if self.group_max_members < 2: + raise ValueError("group_max_members must be at least 2") + if self.group_call_max_participants < 2: + raise ValueError("group_call_max_participants must be at least 2") + if self.group_call_max_participants > self.group_max_members: + raise ValueError("group_call_max_participants cannot exceed group_max_members") + if self.group_invite_rate_per_minute <= 0: + raise ValueError("group_invite_rate_per_minute must be greater than 0") + if self.group_create_rate_per_hour <= 0: + raise ValueError("group_create_rate_per_hour must be greater than 0") + if self.group_call_start_rate_per_minute <= 0: + raise ValueError("group_call_start_rate_per_minute must be greater than 0") + if self.group_call_ring_ttl_seconds <= 0: + raise ValueError("group_call_ring_ttl_seconds must be greater than 0") + if self.attachment_inline_max_bytes > self.attachment_hard_ceiling_bytes: + raise ValueError("attachment_inline_max_bytes cannot exceed attachment_hard_ceiling_bytes") + if self.max_ciphertext_bytes > self.attachment_hard_ceiling_bytes: + raise ValueError("max_ciphertext_bytes cannot exceed attachment_hard_ceiling_bytes") + if self.max_client_message_body_bytes < self.effective_max_ciphertext_bytes(): + raise ValueError("max_client_message_body_bytes cannot be smaller than max_ciphertext_bytes") + if self.max_federation_body_bytes < self.effective_max_ciphertext_bytes(): + raise ValueError("max_federation_body_bytes cannot be smaller than max_ciphertext_bytes") + return self + @lru_cache def get_settings() -> Settings: diff --git a/server/app/dependencies.py b/server/app/dependencies.py index d21bbd3..6b0133b 100644 --- a/server/app/dependencies.py +++ b/server/app/dependencies.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from dataclasses import dataclass from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -6,10 +7,19 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.db import get_db_session +from app.models.device import Device from app.models.user import User from app.security.tokens import TokenError, decode_token +from app.security.tokens_v2 import TokenErrorV2, decode_token as decode_token_v2 bearer_scheme = HTTPBearer(auto_error=False) +bootstrap_bearer_scheme = HTTPBearer(auto_error=False) + + +@dataclass(slots=True) +class AuthenticatedDeviceContextV2: + user: User + device: Device async def db_session() -> AsyncGenerator[AsyncSession, None]: @@ -40,3 +50,61 @@ async def get_current_user( if user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") return user + + +async def get_bootstrap_user_v2( + credentials: HTTPAuthorizationCredentials | None = Depends(bootstrap_bearer_scheme), + session: AsyncSession = Depends(db_session), +) -> User: + if credentials is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + + token = credentials.credentials + try: + payload = decode_token_v2(token, expected_type="bootstrap") + except TokenErrorV2 as exc: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc + + user_id = payload.get("sub") + if not user_id: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid bootstrap token") + + stmt = select(User).where(User.id == user_id, User.disabled_at.is_(None)) + result = await session.execute(stmt) + user = result.scalar_one_or_none() + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") + return user + + +async def get_current_device_context_v2( + credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme), + session: AsyncSession = Depends(db_session), +) -> AuthenticatedDeviceContextV2: + if credentials is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + + token = credentials.credentials + try: + payload = decode_token_v2(token, expected_type="access") + except TokenErrorV2 as exc: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc + + user_id = payload.get("sub") + device_uid = payload.get("did") + if not user_id or not device_uid: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid access token") + + user_stmt = select(User).where(User.id == user_id, User.disabled_at.is_(None)) + user = (await session.execute(user_stmt)).scalar_one_or_none() + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") + + device_stmt = select(Device).where(Device.id == str(device_uid), Device.user_id == user.id) + device = (await session.execute(device_stmt)).scalar_one_or_none() + if device is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Device not found") + if device.status != "active" or device.revoked_at is not None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Device revoked") + + return AuthenticatedDeviceContextV2(user=user, device=device) diff --git a/server/app/main.py b/server/app/main.py index b9bbe23..bb42af6 100644 --- a/server/app/main.py +++ b/server/app/main.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from fastapi import FastAPI, WebSocket, WebSocketDisconnect @@ -6,6 +7,17 @@ from pydantic import ValidationError from app.api import auth, conversations, devices, federation, health, messages, metrics, users +from app.api_v2 import ( + auth as auth_v2, + conversations as conversations_v2, + devices as devices_v2, + federation as federation_v2, + keys as keys_v2, + messages as messages_v2, + presence as presence_v2, + users as users_v2, + ws as ws_v2, +) from app.config import get_settings from app.db import get_session_factory, init_engine, init_models from app.logging_utils import RequestContextMiddleware, configure_logging @@ -20,6 +32,7 @@ from app.services.call_service import CallProtocolError, call_service from app.services.federation_outbox_service import federation_outbox_service from app.services.message_service import message_service +from app.services.message_service_v2 import message_service_v2 from app.services.queue_worker import queue_cleanup_worker from app.services.rate_limit import rate_limiter from app.services.server_identity import ( @@ -68,6 +81,14 @@ def create_app() -> FastAPI: app.include_router(messages.router, prefix=settings.api_prefix) app.include_router(federation.router, prefix=settings.api_prefix) app.include_router(metrics.router, prefix=settings.api_prefix) + app.include_router(auth_v2.router) + app.include_router(devices_v2.router) + app.include_router(keys_v2.router) + app.include_router(users_v2.router) + app.include_router(presence_v2.router) + app.include_router(conversations_v2.router) + app.include_router(messages_v2.router) + app.include_router(federation_v2.router) @app.on_event("startup") async def on_startup() -> None: @@ -75,6 +96,17 @@ async def on_startup() -> None: raise RuntimeError("BLACKWIRE_JWT_SECRET_KEY must be at least 32 bytes") if settings.environment != "dev" and "*" in settings.allow_origins: raise RuntimeError("Wildcard CORS origin is not allowed outside dev") + if settings.enable_webrtc_v2b2 and not settings.webrtc_ice_servers_json.strip(): + raise RuntimeError( + "BLACKWIRE_WEBRTC_ICE_SERVERS_JSON is required when BLACKWIRE_ENABLE_WEBRTC_V2B2=true" + ) + if settings.enable_webrtc_v2b2 and settings.webrtc_ice_servers_json.strip(): + try: + parsed_ice = json.loads(settings.webrtc_ice_servers_json) + except json.JSONDecodeError as exc: + raise RuntimeError("BLACKWIRE_WEBRTC_ICE_SERVERS_JSON must be valid JSON") from exc + if not isinstance(parsed_ice, list) or not parsed_ice: + raise RuntimeError("BLACKWIRE_WEBRTC_ICE_SERVERS_JSON must be a non-empty JSON array") initialize_server_identity(settings) server_onion = get_server_onion() @@ -96,7 +128,9 @@ async def on_startup() -> None: async def cleanup_once() -> int: session_factory = get_session_factory() async with session_factory() as session: - return await message_service.expire_old(session) + expired_v1 = await message_service.expire_old(session) + expired_v2 = await message_service_v2.expire_old(session) + return expired_v1 + expired_v2 async def federation_outbox_once() -> int: session_factory = get_session_factory() @@ -250,6 +284,10 @@ async def send_call_error(code: str, detail: str) -> None: await call_service.handle_disconnect(user_id) await connection_manager.disconnect(user_id=user_id, websocket=websocket) + @app.websocket("/api/v2/ws") + async def websocket_endpoint_v2(websocket: WebSocket) -> None: + await ws_v2.websocket_endpoint_v2(websocket) + return app diff --git a/server/app/models/__init__.py b/server/app/models/__init__.py index fe8909e..95a27c8 100644 --- a/server/app/models/__init__.py +++ b/server/app/models/__init__.py @@ -1,10 +1,20 @@ from app.models.conversation import Conversation +from app.models.conversation_member import ConversationMember from app.models.delivery_queue import DeliveryQueue from app.models.device import ActiveDevice, Device from app.models.federation_nonce_replay import FederationNonceReplay from app.models.federation_outbox import FederationOutbox from app.models.federation_peer import FederationPeer +from app.models.group_call_participant import GroupCallParticipant +from app.models.group_call_session import GroupCallSession +from app.models.group_membership_event import GroupMembershipEvent from app.models.message import Message +from app.models.message_device_copy import MessageDeviceCopy +from app.models.message_event import MessageEvent +from app.models.device_signed_prekey import DeviceSignedPrekey +from app.models.device_one_time_prekey import DeviceOneTimePrekey +from app.models.ratchet_session import RatchetSession +from app.models.ratchet_skipped_key import RatchetSkippedKey from app.models.refresh_token import RefreshToken from app.models.user import User @@ -13,7 +23,17 @@ "Device", "ActiveDevice", "Conversation", + "ConversationMember", + "GroupMembershipEvent", + "GroupCallSession", + "GroupCallParticipant", "Message", + "MessageEvent", + "MessageDeviceCopy", + "DeviceSignedPrekey", + "DeviceOneTimePrekey", + "RatchetSession", + "RatchetSkippedKey", "DeliveryQueue", "FederationPeer", "FederationNonceReplay", diff --git a/server/app/models/conversation.py b/server/app/models/conversation.py index ea76146..d21bc35 100644 --- a/server/app/models/conversation.py +++ b/server/app/models/conversation.py @@ -30,6 +30,11 @@ class Conversation(Base): peer_username: Mapped[str] = mapped_column(String(64), default="") peer_server_onion: Mapped[str] = mapped_column(String(255), default="") peer_address: Mapped[str] = mapped_column(String(320), default="", index=True) + conversation_type: Mapped[str] = mapped_column(String(16), default="direct", index=True) + group_uid: Mapped[str | None] = mapped_column(String(64), nullable=True, unique=True, index=True) + group_name: Mapped[str] = mapped_column(String(128), default="") + origin_server_onion: Mapped[str] = mapped_column(String(255), default="") + owner_address: Mapped[str] = mapped_column(String(320), default="") created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(UTC) diff --git a/server/app/models/conversation_member.py b/server/app/models/conversation_member.py new file mode 100644 index 0000000..120684a --- /dev/null +++ b/server/app/models/conversation_member.py @@ -0,0 +1,37 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class ConversationMember(Base): + __tablename__ = "conversation_members" + __table_args__ = ( + UniqueConstraint("conversation_id", "member_address", name="uq_conversation_member_address"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + conversation_id: Mapped[str] = mapped_column( + ForeignKey("conversations.id", ondelete="CASCADE"), + index=True, + ) + member_user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), + index=True, + nullable=True, + ) + member_address: Mapped[str] = mapped_column(String(320), index=True) + member_server_onion: Mapped[str] = mapped_column(String(255), index=True) + role: Mapped[str] = mapped_column(String(16), default="member", index=True) + status: Mapped[str] = mapped_column(String(16), default="invited", index=True) + invited_by_address: Mapped[str] = mapped_column(String(320), default="") + invited_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + joined_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + left_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + ) diff --git a/server/app/models/device.py b/server/app/models/device.py index 4b2f285..ca06814 100644 --- a/server/app/models/device.py +++ b/server/app/models/device.py @@ -15,9 +15,13 @@ class Device(Base): label: Mapped[str] = mapped_column(String(100)) ik_ed25519_pub: Mapped[str] = mapped_column(String(256)) enc_x25519_pub: Mapped[str] = mapped_column(String(256), unique=True) + status: Mapped[str] = mapped_column(String(16), index=True, default="active") created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(UTC) ) + last_seen_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(UTC) + ) revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) user = relationship("User", back_populates="devices") diff --git a/server/app/models/device_one_time_prekey.py b/server/app/models/device_one_time_prekey.py new file mode 100644 index 0000000..db078dd --- /dev/null +++ b/server/app/models/device_one_time_prekey.py @@ -0,0 +1,22 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import DateTime, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class DeviceOneTimePrekey(Base): + __tablename__ = "device_one_time_prekeys" + __table_args__ = ( + UniqueConstraint("device_id", "key_id", name="uq_device_one_time_prekey_device_key"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + device_id: Mapped[str] = mapped_column(ForeignKey("devices.id", ondelete="CASCADE"), index=True) + key_id: Mapped[int] = mapped_column(Integer) + pub_x25519_b64: Mapped[str] = mapped_column(String(128)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), index=True, nullable=True) + consumed_by_address: Mapped[str | None] = mapped_column(String(320), nullable=True) diff --git a/server/app/models/device_signed_prekey.py b/server/app/models/device_signed_prekey.py new file mode 100644 index 0000000..2d73eb0 --- /dev/null +++ b/server/app/models/device_signed_prekey.py @@ -0,0 +1,25 @@ +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import DateTime, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class DeviceSignedPrekey(Base): + __tablename__ = "device_signed_prekeys" + __table_args__ = ( + UniqueConstraint("device_id", "key_id", name="uq_device_signed_prekey_device_key"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + device_id: Mapped[str] = mapped_column(ForeignKey("devices.id", ondelete="CASCADE"), index=True) + key_id: Mapped[int] = mapped_column(Integer) + pub_x25519_b64: Mapped[str] = mapped_column(String(128)) + sig_by_device_sign_key_b64: Mapped[str] = mapped_column(String(256)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(UTC) + timedelta(days=30) + ) + revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) diff --git a/server/app/models/group_call_participant.py b/server/app/models/group_call_participant.py new file mode 100644 index 0000000..a0360e0 --- /dev/null +++ b/server/app/models/group_call_participant.py @@ -0,0 +1,31 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class GroupCallParticipant(Base): + __tablename__ = "group_call_participants" + __table_args__ = ( + UniqueConstraint("call_id", "member_address", name="uq_group_call_participant_address"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + call_id: Mapped[str] = mapped_column( + ForeignKey("group_call_sessions.call_id", ondelete="CASCADE"), + index=True, + ) + member_address: Mapped[str] = mapped_column(String(320), index=True) + local_user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), + index=True, + nullable=True, + ) + state: Mapped[str] = mapped_column(String(24), default="invited", index=True) + invited_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + joined_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + left_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + last_signal_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) diff --git a/server/app/models/group_call_session.py b/server/app/models/group_call_session.py new file mode 100644 index 0000000..9e6ac2b --- /dev/null +++ b/server/app/models/group_call_session.py @@ -0,0 +1,24 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import DateTime, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class GroupCallSession(Base): + __tablename__ = "group_call_sessions" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + call_id: Mapped[str] = mapped_column(String(36), unique=True, index=True) + conversation_id: Mapped[str] = mapped_column( + ForeignKey("conversations.id", ondelete="CASCADE"), + index=True, + ) + group_uid: Mapped[str] = mapped_column(String(64), index=True) + initiator_address: Mapped[str] = mapped_column(String(320), index=True) + state: Mapped[str] = mapped_column(String(16), default="ringing", index=True) + started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + ended_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + ring_expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) diff --git a/server/app/models/group_membership_event.py b/server/app/models/group_membership_event.py new file mode 100644 index 0000000..2bd15bf --- /dev/null +++ b/server/app/models/group_membership_event.py @@ -0,0 +1,24 @@ +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import JSON, DateTime, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class GroupMembershipEvent(Base): + __tablename__ = "group_membership_events" + __table_args__ = ( + UniqueConstraint("group_uid", "event_seq", name="uq_group_membership_event_seq"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + group_uid: Mapped[str] = mapped_column(String(64), index=True) + event_seq: Mapped[int] = mapped_column(Integer, index=True) + event_type: Mapped[str] = mapped_column(String(24), index=True) + actor_address: Mapped[str] = mapped_column(String(320), default="") + target_address: Mapped[str] = mapped_column(String(320), default="") + payload_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) diff --git a/server/app/models/message_device_copy.py b/server/app/models/message_device_copy.py new file mode 100644 index 0000000..8676ac3 --- /dev/null +++ b/server/app/models/message_device_copy.py @@ -0,0 +1,29 @@ +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import JSON, DateTime, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class MessageDeviceCopy(Base): + __tablename__ = "message_device_copies" + __table_args__ = ( + UniqueConstraint("message_event_id", "recipient_device_uid", name="uq_message_copy_event_recipient_device"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + message_event_id: Mapped[str] = mapped_column(ForeignKey("message_events.id", ondelete="CASCADE"), index=True) + recipient_user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) + recipient_device_uid: Mapped[str] = mapped_column(String(36), index=True) + envelope_json: Mapped[dict[str, Any]] = mapped_column(JSON) + status: Mapped[str] = mapped_column(String(20), index=True, default="pending") + available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) + delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + attempt_count: Mapped[int] = mapped_column(Integer, default=0) + last_attempt_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + diff --git a/server/app/models/message_event.py b/server/app/models/message_event.py new file mode 100644 index 0000000..ed8e7a1 --- /dev/null +++ b/server/app/models/message_event.py @@ -0,0 +1,31 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import BigInteger, DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class MessageEvent(Base): + __tablename__ = "message_events" + __table_args__ = ( + UniqueConstraint("sender_device_uid", "client_message_id", name="uq_message_event_sender_client_message"), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + conversation_id: Mapped[str] = mapped_column( + ForeignKey("conversations.id", ondelete="CASCADE"), index=True + ) + sender_user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), index=True, nullable=True + ) + sender_address: Mapped[str] = mapped_column(String(320), index=True, default="") + sender_device_uid: Mapped[str] = mapped_column(String(36), index=True) + sender_device_pubkey: Mapped[str] = mapped_column(String(256)) + client_message_id: Mapped[str] = mapped_column(String(64), index=True) + sent_at_ms: Mapped[int] = mapped_column(BigInteger, index=True) + encryption_mode: Mapped[str] = mapped_column(String(32), index=True, default="sealedbox_v0_2a") + sender_prev_hash: Mapped[str] = mapped_column(String(128), default="") + sender_chain_hash: Mapped[str] = mapped_column(String(128), index=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) diff --git a/server/app/models/ratchet_session.py b/server/app/models/ratchet_session.py new file mode 100644 index 0000000..3265b43 --- /dev/null +++ b/server/app/models/ratchet_session.py @@ -0,0 +1,33 @@ +import uuid +from datetime import UTC, datetime + +from sqlalchemy import DateTime, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class RatchetSession(Base): + __tablename__ = "ratchet_sessions" + __table_args__ = ( + UniqueConstraint( + "owner_device_uid", + "peer_address", + "peer_device_uid", + name="uq_ratchet_session_owner_peer", + ), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) + owner_device_uid: Mapped[str] = mapped_column(ForeignKey("devices.id", ondelete="CASCADE"), index=True) + peer_address: Mapped[str] = mapped_column(String(320), index=True) + peer_device_uid: Mapped[str] = mapped_column(String(64), index=True) + session_version: Mapped[str] = mapped_column(String(16), default="dr_v1") + state_blob_encrypted_b64: Mapped[str] = mapped_column(String(65536), default="") + state_nonce_b64: Mapped[str] = mapped_column(String(128), default="") + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + last_send_chain_n: Mapped[int] = mapped_column(Integer, default=0) + last_recv_chain_n: Mapped[int] = mapped_column(Integer, default=0) + last_root_key_hash: Mapped[str] = mapped_column(String(128), default="") diff --git a/server/app/models/ratchet_skipped_key.py b/server/app/models/ratchet_skipped_key.py new file mode 100644 index 0000000..e521df6 --- /dev/null +++ b/server/app/models/ratchet_skipped_key.py @@ -0,0 +1,34 @@ +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import DateTime, ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from app.models.base import Base + + +class RatchetSkippedKey(Base): + __tablename__ = "ratchet_skipped_keys" + __table_args__ = ( + UniqueConstraint( + "owner_device_uid", + "peer_address", + "peer_device_uid", + "dh_pub_b64", + "msg_n", + name="uq_ratchet_skipped_owner_peer_dh_n", + ), + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_device_uid: Mapped[str] = mapped_column(ForeignKey("devices.id", ondelete="CASCADE"), index=True) + peer_address: Mapped[str] = mapped_column(String(320), index=True) + peer_device_uid: Mapped[str] = mapped_column(String(64), index=True) + dh_pub_b64: Mapped[str] = mapped_column(String(128)) + msg_n: Mapped[int] = mapped_column(Integer) + mk_encrypted_b64: Mapped[str] = mapped_column(String(512)) + mk_nonce_b64: Mapped[str] = mapped_column(String(128)) + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(UTC) + timedelta(days=7) + ) + used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) diff --git a/server/app/models/refresh_token.py b/server/app/models/refresh_token.py index 0df8ebd..7c782bb 100644 --- a/server/app/models/refresh_token.py +++ b/server/app/models/refresh_token.py @@ -12,7 +12,11 @@ class RefreshToken(Base): id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) + device_id: Mapped[str | None] = mapped_column( + ForeignKey("devices.id", ondelete="CASCADE"), index=True, nullable=True + ) token_hash: Mapped[str] = mapped_column(String(128), unique=True, index=True) + token_kind: Mapped[str] = mapped_column(String(16), default="refresh", index=True) issued_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True) revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) diff --git a/server/app/schemas/call.py b/server/app/schemas/call.py index 6608f5f..f23f00b 100644 --- a/server/app/schemas/call.py +++ b/server/app/schemas/call.py @@ -35,3 +35,35 @@ def validate_pcm_b64(cls, value: str) -> str: except binascii.Error as exc: raise ValueError("pcm_b64 must be valid base64") from exc return value + + +class CallWebRtcOfferRequest(BaseModel): + call_id: str + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) + + +class CallWebRtcAnswerRequest(BaseModel): + call_id: str + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) + + +class CallWebRtcIceRequest(BaseModel): + call_id: str + candidate: str = Field(min_length=1, max_length=16384) + sdp_mid: str | None = Field(default=None, max_length=128) + sdp_mline_index: int | None = Field(default=None, ge=0) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) diff --git a/server/app/schemas/v2_auth.py b/server/app/schemas/v2_auth.py new file mode 100644 index 0000000..b2421fc --- /dev/null +++ b/server/app/schemas/v2_auth.py @@ -0,0 +1,69 @@ +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + +from app.schemas.user import UserOut + + +class RegisterRequestV2(BaseModel): + username: str = Field(min_length=3, max_length=64, pattern=r"^[A-Za-z0-9_]{3,64}$") + password: str = Field(min_length=8, max_length=128) + + +class LoginRequestV2(BaseModel): + username: str = Field(min_length=3, max_length=64, pattern=r"^[A-Za-z0-9_]{3,64}$") + password: str + + +class BindDeviceRequestV2(BaseModel): + device_uid: str + nonce: str = Field(min_length=8, max_length=128) + timestamp_ms: int = Field(ge=0) + proof_signature_b64: str = Field(min_length=16, max_length=1024) + + +class RefreshRequestV2(BaseModel): + refresh_token: str + + +class LogoutRequestV2(BaseModel): + refresh_token: str + + +class BootstrapTokenBundleV2(BaseModel): + bootstrap_token: str + token_type: str = "bootstrap" + bootstrap_expires_in: int + + +class DeviceTokenBundleV2(BaseModel): + access_token: str + token_type: str = "bearer" + access_expires_in: int + refresh_token: str + refresh_expires_in: int + device_uid: str + + +class BootstrapAuthResponseV2(BaseModel): + user: UserOut + tokens: BootstrapTokenBundleV2 + + +class DeviceAuthResponseV2(BaseModel): + user: UserOut + tokens: DeviceTokenBundleV2 + + +class RefreshTokenRecordV2(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + user_id: str + device_id: str | None + token_kind: str + issued_at: datetime + expires_at: datetime + revoked_at: datetime | None + replaced_by: str | None + diff --git a/server/app/schemas/v2_conversation.py b/server/app/schemas/v2_conversation.py new file mode 100644 index 0000000..b62e5ac --- /dev/null +++ b/server/app/schemas/v2_conversation.py @@ -0,0 +1,91 @@ +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from app.schemas.v2_device import DeviceOutV2 +from app.schemas.v2_prekey import ResolvedPrekeyDeviceV2 + + +class CreateDMConversationRequestV2(BaseModel): + peer_username: str | None = Field( + default=None, + min_length=3, + max_length=64, + pattern=r"^[A-Za-z0-9_]{3,64}$", + ) + peer_address: str | None = Field(default=None, min_length=3, max_length=320) + + @model_validator(mode="after") + def validate_peer_target(self) -> "CreateDMConversationRequestV2": + if (self.peer_username is None or not self.peer_username.strip()) and ( + self.peer_address is None or not self.peer_address.strip() + ): + raise ValueError("Either peer_username or peer_address is required") + return self + + +class CreateGroupConversationRequestV2(BaseModel): + name: str = Field(min_length=1, max_length=128) + member_addresses: list[str] = Field(default_factory=list, max_length=64) + + +class GroupInviteRequestV2(BaseModel): + member_addresses: list[str] = Field(min_length=1, max_length=64) + + +class GroupRenameRequestV2(BaseModel): + name: str = Field(min_length=1, max_length=128) + + +class GroupLeaveRequestV2(BaseModel): + reason: str | None = Field(default=None, max_length=64) + + +class ConversationOutV2(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + kind: str = "local" + user_a_id: str = "" + user_b_id: str = "" + local_user_id: str = "" + created_at: datetime + peer_username: str = "" + peer_server_onion: str = "" + peer_address: str = "" + conversation_type: Literal["direct", "group"] = "direct" + group_uid: str | None = None + group_name: str = "" + member_count: int = 0 + membership_state: Literal["none", "invited", "active", "left", "removed"] = "none" + can_manage_members: bool = False + origin_server_onion: str = "" + owner_address: str = "" + + +class ConversationMemberOutV2(BaseModel): + id: str + member_user_id: str | None = None + member_address: str + member_server_onion: str + role: Literal["owner", "member"] = "member" + status: Literal["invited", "active", "left", "removed"] = "invited" + invited_by_address: str = "" + invited_at: datetime + joined_at: datetime | None = None + left_at: datetime | None = None + updated_at: datetime + + +class ConversationRecipientDeviceOutV2(BaseModel): + member_address: str + member_status: Literal["active", "invited", "left", "removed"] = "active" + device: DeviceOutV2 + prekey: ResolvedPrekeyDeviceV2 | None = None + + +class ConversationRecipientsOutV2(BaseModel): + conversation_id: str + conversation_type: Literal["direct", "group"] = "direct" + recipients: list[ConversationRecipientDeviceOutV2] diff --git a/server/app/schemas/v2_device.py b/server/app/schemas/v2_device.py new file mode 100644 index 0000000..986237e --- /dev/null +++ b/server/app/schemas/v2_device.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class DeviceRegisterRequestV2(BaseModel): + label: str = Field(min_length=1, max_length=100) + pub_sign_key: str = Field(min_length=32, max_length=256) + pub_dh_key: str = Field(min_length=32, max_length=256) + + +class DeviceOutV2(BaseModel): + model_config = ConfigDict(from_attributes=True) + + device_uid: str + user_id: str + label: str + pub_sign_key: str + pub_dh_key: str + status: str + supported_message_modes: list[str] = Field(default_factory=list) + created_at: datetime + last_seen_at: datetime + revoked_at: datetime | None + + +class UserDeviceLookupV2(BaseModel): + username: str + peer_address: str + devices: list[DeviceOutV2] + attachment_inline_max_bytes: int = Field(default=0, ge=0) + max_ciphertext_bytes: int = Field(default=0, ge=0) + attachment_policy_source: Literal["local", "remote", "fallback_local"] = "local" + + +class DeviceResolveListV2(BaseModel): + peer_address: str + devices: list[DeviceOutV2] diff --git a/server/app/schemas/v2_federation.py b/server/app/schemas/v2_federation.py new file mode 100644 index 0000000..d034c32 --- /dev/null +++ b/server/app/schemas/v2_federation.py @@ -0,0 +1,186 @@ +from pydantic import BaseModel, Field + +from app.schemas.v2_message import SignedEnvelopeV2 +from app.schemas.v2_prekey import ResolvePrekeysResponseV2 + + +class FederationWellKnownOutV2(BaseModel): + server_onion: str + federation_version: str = "2" + signing_public_key: str + identity_binding_mode: str = "tor_v3_same_ed25519" + attachment_inline_max_bytes: int = Field(default=10485760, ge=1) + max_ciphertext_bytes: int = Field(default=20971520, ge=1) + attachment_hard_ceiling_bytes: int = Field(default=33554432, ge=1) + supported_message_modes: list[str] = Field(default_factory=lambda: ["sealedbox_v0_2a"]) + supported_call_modes: list[str] = Field(default_factory=lambda: ["ws_pcm_v0_2a"]) + + +class FederationUserDevicesOutV2(BaseModel): + username: str + peer_address: str + devices: list[dict] + + +class FederationMessageRelayEnvelopeV2(SignedEnvelopeV2): + recipient_user_id: str = "" + + +class FederationMessageRelayRequestV2(BaseModel): + relay_id: str + conversation_id: str + sender_address: str + sender_device_uid: str = Field(min_length=1, max_length=64) + sender_user_id: str | None = None + encryption_mode: str = Field(default="sealedbox_v0_2a", max_length=32) + client_message_id: str = Field(min_length=8, max_length=128) + sent_at_ms: int = Field(ge=0) + sender_prev_hash: str = Field(max_length=128, default="") + sender_chain_hash: str = Field(min_length=16, max_length=128) + envelopes: list[FederationMessageRelayEnvelopeV2] = Field(min_length=1, max_length=256) + group_uid: str | None = Field(default=None, max_length=64) + + +class FederationUserPrekeysOutV2(ResolvePrekeysResponseV2): + pass + + +class FederationGroupEventRequestV2(BaseModel): + relay_id: str + group_uid: str = Field(min_length=8, max_length=64) + event_seq: int = Field(ge=1) + event_type: str = Field(min_length=3, max_length=24) + actor_address: str = Field(min_length=3, max_length=320) + target_address: str = Field(default="", max_length=320) + payload: dict = Field(default_factory=dict) + created_at: str = Field(min_length=8, max_length=128) + + +class FederationGroupSnapshotOutV2(BaseModel): + group_uid: str + conversation_id: str + group_name: str + origin_server_onion: str + owner_address: str + latest_event_seq: int + members: list[dict] + + +class FederationGroupInviteAcceptRequestV2(BaseModel): + relay_id: str + group_uid: str = Field(min_length=8, max_length=64) + conversation_id: str = Field(min_length=8, max_length=64) + actor_address: str = Field(min_length=3, max_length=320) + + +class FederationCallWebRtcOfferRequestV2(BaseModel): + relay_id: str + call_id: str + from_user_address: str + to_user_address: str + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) + + +class FederationCallWebRtcAnswerRequestV2(BaseModel): + relay_id: str + call_id: str + from_user_address: str + to_user_address: str + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) + + +class FederationCallWebRtcIceRequestV2(BaseModel): + relay_id: str + call_id: str + from_user_address: str + to_user_address: str + candidate: str = Field(min_length=1, max_length=16384) + sdp_mid: str | None = Field(default=None, max_length=128) + sdp_mline_index: int | None = Field(default=None, ge=0) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=2, ge=2, le=8) + target_user_address: str = Field(default="", max_length=320) + source_user_address: str = Field(default="", max_length=320) + + +class FederationGroupCallOfferRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + conversation_id: str = Field(min_length=8, max_length=64) + from_user_address: str = Field(min_length=3, max_length=320) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=8, ge=2, le=8) + ring_expires_at: str = Field(min_length=8, max_length=128) + + +class FederationGroupCallJoinRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + member_address: str = Field(min_length=3, max_length=320) + + +class FederationGroupCallLeaveRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + member_address: str = Field(min_length=3, max_length=320) + reason: str = Field(default="left", max_length=64) + + +class FederationGroupCallEndRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + reason: str = Field(default="ended", max_length=64) + + +class FederationGroupCallWebRtcOfferRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + source_user_address: str = Field(min_length=3, max_length=320) + target_user_address: str = Field(min_length=3, max_length=320) + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=8, ge=2, le=8) + + +class FederationGroupCallWebRtcAnswerRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + source_user_address: str = Field(min_length=3, max_length=320) + target_user_address: str = Field(min_length=3, max_length=320) + sdp: str = Field(min_length=8, max_length=131072) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=8, ge=2, le=8) + + +class FederationGroupCallWebRtcIceRequestV2(BaseModel): + relay_id: str + call_id: str + group_uid: str = Field(min_length=8, max_length=64) + source_user_address: str = Field(min_length=3, max_length=320) + target_user_address: str = Field(min_length=3, max_length=320) + candidate: str = Field(min_length=1, max_length=16384) + sdp_mid: str | None = Field(default=None, max_length=128) + sdp_mline_index: int | None = Field(default=None, ge=0) + call_schema_version: int = Field(default=1, ge=1) + call_mode: str = Field(default="webrtc", max_length=32) + max_participants: int = Field(default=8, ge=2, le=8) diff --git a/server/app/schemas/v2_message.py b/server/app/schemas/v2_message.py new file mode 100644 index 0000000..35b7227 --- /dev/null +++ b/server/app/schemas/v2_message.py @@ -0,0 +1,98 @@ +import base64 +import binascii +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class RatchetHeaderV2(BaseModel): + v: str = Field(default="dr_v1", min_length=3, max_length=16) + dh_pub: str = Field(min_length=16, max_length=256) + n: int = Field(ge=0) + pn: int = Field(ge=0) + + +class RatchetInitV2(BaseModel): + scheme: str = Field(default="x3dh_v1", min_length=4, max_length=32) + sender_ephemeral_pub: str = Field(min_length=16, max_length=256) + signed_prekey_id: int | None = Field(default=None, ge=1) + one_time_prekey_id: int | None = Field(default=None, ge=1) + opk_missing: bool = False + + +class SignedEnvelopeV2(BaseModel): + recipient_user_address: str + recipient_device_uid: str + ciphertext_b64: str + aad_b64: str | None = None + signature_b64: str + sender_device_pubkey: str + ratchet_header: RatchetHeaderV2 | None = None + ratchet_init: RatchetInitV2 | None = None + + @field_validator("ciphertext_b64") + @classmethod + def validate_ciphertext_b64(cls, value: str) -> str: + try: + base64.b64decode(value, validate=True) + except binascii.Error as exc: + raise ValueError("ciphertext_b64 must be valid base64") from exc + return value + + @field_validator("aad_b64") + @classmethod + def validate_aad_b64(cls, value: str | None) -> str | None: + if value is None: + return value + try: + base64.b64decode(value, validate=True) + except binascii.Error as exc: + raise ValueError("aad_b64 must be valid base64") from exc + return value + + +class MessageSendRequestV2(BaseModel): + conversation_id: str + encryption_mode: str = Field(default="sealedbox_v0_2a", max_length=32) + client_message_id: str = Field(min_length=8, max_length=128) + sent_at_ms: int = Field(ge=0) + sender_prev_hash: str = Field(max_length=128, default="") + sender_chain_hash: str = Field(min_length=16, max_length=128) + envelopes: list[SignedEnvelopeV2] = Field(min_length=1, max_length=256) + + +class MessageEventOutV2(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + conversation_id: str + sender_user_id: str | None = None + sender_address: str + sender_device_uid: str + sender_device_pubkey: str + encryption_mode: str = "sealedbox_v0_2a" + client_message_id: str + sent_at_ms: int + sender_prev_hash: str + sender_chain_hash: str + created_at: datetime + + +class MessageDeviceCopyOutV2(BaseModel): + model_config = ConfigDict(from_attributes=True) + + copy_id: str + message: MessageEventOutV2 + recipient_device_uid: str + envelope_json: dict + status: str + created_at: datetime + + +class MessageSendResponseV2(BaseModel): + duplicate: bool + message: MessageEventOutV2 + + +class MessageAckRequestV2(BaseModel): + copy_id: str diff --git a/server/app/schemas/v2_prekey.py b/server/app/schemas/v2_prekey.py new file mode 100644 index 0000000..41e95ed --- /dev/null +++ b/server/app/schemas/v2_prekey.py @@ -0,0 +1,53 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + + +class SignedPrekeyUploadV2(BaseModel): + key_id: int = Field(ge=1) + pub_x25519_b64: str = Field(min_length=16, max_length=128) + sig_by_device_sign_key_b64: str = Field(min_length=16, max_length=256) + expires_at: datetime | None = None + + +class OneTimePrekeyUploadV2(BaseModel): + key_id: int = Field(ge=1) + pub_x25519_b64: str = Field(min_length=16, max_length=128) + + +class PrekeyUploadRequestV2(BaseModel): + signed_prekey: SignedPrekeyUploadV2 + one_time_prekeys: list[OneTimePrekeyUploadV2] = Field(default_factory=list, max_length=512) + + +class SignedPrekeyOutV2(BaseModel): + key_id: int + pub_x25519_b64: str + sig_by_device_sign_key_b64: str + expires_at: datetime + + +class OneTimePrekeyOutV2(BaseModel): + key_id: int + pub_x25519_b64: str + + +class ResolvedPrekeyDeviceV2(BaseModel): + device_uid: str + pub_sign_key: str + pub_dh_key: str + supported_message_modes: list[str] + signed_prekey: SignedPrekeyOutV2 | None = None + one_time_prekey: OneTimePrekeyOutV2 | None = None + opk_missing: bool = False + + +class ResolvePrekeysResponseV2(BaseModel): + username: str + peer_address: str + devices: list[ResolvedPrekeyDeviceV2] + + +class PrekeyUploadResponseV2(BaseModel): + uploaded_signed_prekey_key_id: int + accepted_one_time_prekeys: int diff --git a/server/app/schemas/v2_presence.py b/server/app/schemas/v2_presence.py new file mode 100644 index 0000000..bd5405f --- /dev/null +++ b/server/app/schemas/v2_presence.py @@ -0,0 +1,49 @@ +from pydantic import BaseModel, Field, field_validator + + +_VALID_STATUSES = {"active", "inactive", "offline", "dnd"} + + +class PresenceSetRequest(BaseModel): + status: str = Field(min_length=3, max_length=16) + + @field_validator("status") + @classmethod + def validate_status(cls, value: str) -> str: + normalized = value.strip().lower() + if normalized not in _VALID_STATUSES: + raise ValueError("status must be one of: active, inactive, offline, dnd") + return normalized + + +class PresenceSetResponse(BaseModel): + status: str + + +class PresenceResolveRequest(BaseModel): + peer_addresses: list[str] = Field(default_factory=list, max_length=256) + + @field_validator("peer_addresses") + @classmethod + def normalize_addresses(cls, value: list[str]) -> list[str]: + normalized: list[str] = [] + seen: set[str] = set() + for raw in value: + item = raw.strip().lower() + if not item: + continue + if item in seen: + continue + seen.add(item) + normalized.append(item) + return normalized + + +class PresencePeerOut(BaseModel): + peer_address: str + status: str + + +class PresenceResolveResponse(BaseModel): + self_status: str + peers: list[PresencePeerOut] diff --git a/server/app/security/tokens_v2.py b/server/app/security/tokens_v2.py new file mode 100644 index 0000000..c2f1ce2 --- /dev/null +++ b/server/app/security/tokens_v2.py @@ -0,0 +1,164 @@ +import hashlib +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from functools import lru_cache +from typing import Any + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from app.config import get_settings + + +class TokenErrorV2(ValueError): + pass + + +@dataclass(slots=True) +class DeviceTokenPairV2: + access_token: str + access_expires_in: int + refresh_token: str + refresh_expires_in: int + refresh_token_id: str + refresh_expires_at: datetime + + +def _now_utc() -> datetime: + return datetime.now(UTC) + + +def _as_pem_bytes(value: str) -> bytes: + normalized = value.strip() + if not normalized: + return b"" + return normalized.encode("utf-8") + + +@lru_cache +def _signing_private_key() -> ed25519.Ed25519PrivateKey: + settings = get_settings() + private_pem = _as_pem_bytes(settings.jwt_private_key_pem) + if private_pem: + key = serialization.load_pem_private_key(private_pem, password=None) + if not isinstance(key, ed25519.Ed25519PrivateKey): + raise RuntimeError("BLACKWIRE_JWT_PRIVATE_KEY_PEM must be an Ed25519 key") + return key + + seed = hashlib.sha256(f"{settings.jwt_secret_key}|blackwire-v2-jwt-seed".encode("utf-8")).digest() + return ed25519.Ed25519PrivateKey.from_private_bytes(seed[:32]) + + +@lru_cache +def _verify_public_key() -> ed25519.Ed25519PublicKey: + settings = get_settings() + public_pem = _as_pem_bytes(settings.jwt_public_key_pem) + if public_pem: + key = serialization.load_pem_public_key(public_pem) + if not isinstance(key, ed25519.Ed25519PublicKey): + raise RuntimeError("BLACKWIRE_JWT_PUBLIC_KEY_PEM must be an Ed25519 key") + return key + return _signing_private_key().public_key() + + +def reset_v2_token_cache() -> None: + _signing_private_key.cache_clear() + _verify_public_key.cache_clear() + + +def _encode(payload: dict[str, Any]) -> str: + token = jwt.encode(payload, _signing_private_key(), algorithm="EdDSA") + if isinstance(token, bytes): + return token.decode("utf-8") + return token + + +def create_bootstrap_token(user_id: str, username: str) -> tuple[str, int]: + settings = get_settings() + now = _now_utc() + expiry = now + timedelta(seconds=max(30, settings.v2_bootstrap_token_seconds)) + payload = { + "sub": user_id, + "username": username, + "did": "", + "type": "bootstrap", + "jti": str(uuid.uuid4()), + "iat": int(now.timestamp()), + "exp": int(expiry.timestamp()), + } + ttl = int((expiry - now).total_seconds()) + return _encode(payload), ttl + + +def create_access_token(user_id: str, username: str, device_uid: str) -> tuple[str, int]: + settings = get_settings() + now = _now_utc() + expiry = now + timedelta(minutes=max(1, settings.v2_access_token_minutes)) + payload = { + "sub": user_id, + "username": username, + "did": device_uid, + "type": "access", + "jti": str(uuid.uuid4()), + "iat": int(now.timestamp()), + "exp": int(expiry.timestamp()), + } + return _encode(payload), int((expiry - now).total_seconds()) + + +def create_refresh_token(user_id: str, device_uid: str, token_id: str | None = None) -> tuple[str, str, datetime, int]: + settings = get_settings() + now = _now_utc() + expiry = now + timedelta(days=max(1, settings.v2_refresh_token_days)) + refresh_id = token_id or str(uuid.uuid4()) + payload = { + "sub": user_id, + "username": "", + "did": device_uid, + "type": "refresh", + "rid": refresh_id, + "jti": str(uuid.uuid4()), + "iat": int(now.timestamp()), + "exp": int(expiry.timestamp()), + } + token = _encode(payload) + return token, refresh_id, expiry, int((expiry - now).total_seconds()) + + +def decode_token(token: str, expected_type: str | None = None) -> dict[str, Any]: + try: + payload = jwt.decode( + token, + _verify_public_key(), + algorithms=["EdDSA"], + options={"require": ["exp", "iat", "sub", "did", "type", "jti"]}, + ) + except jwt.ExpiredSignatureError as exc: + raise TokenErrorV2("Token expired") from exc + except jwt.InvalidTokenError as exc: + raise TokenErrorV2("Invalid token") from exc + + token_type = payload.get("type") + if expected_type and token_type != expected_type: + raise TokenErrorV2(f"Expected {expected_type} token") + return payload + + +def build_device_token_pair(user_id: str, username: str, device_uid: str) -> DeviceTokenPairV2: + access_token, access_ttl = create_access_token(user_id=user_id, username=username, device_uid=device_uid) + refresh_token, refresh_id, refresh_expiry, refresh_ttl = create_refresh_token(user_id=user_id, device_uid=device_uid) + return DeviceTokenPairV2( + access_token=access_token, + access_expires_in=access_ttl, + refresh_token=refresh_token, + refresh_expires_in=refresh_ttl, + refresh_token_id=refresh_id, + refresh_expires_at=refresh_expiry, + ) + + +def hash_refresh_token(token: str) -> str: + return hashlib.sha256(token.encode("utf-8")).hexdigest() + diff --git a/server/app/services/auth_service_v2.py b/server/app/services/auth_service_v2.py new file mode 100644 index 0000000..d3745e4 --- /dev/null +++ b/server/app/services/auth_service_v2.py @@ -0,0 +1,227 @@ +import time +from datetime import UTC, datetime + +from nacl import encoding, signing +from nacl import exceptions as nacl_exceptions +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.refresh_token import RefreshToken +from app.models.user import User +from app.schemas.v2_auth import ( + BindDeviceRequestV2, + BootstrapTokenBundleV2, + DeviceTokenBundleV2, +) +from app.schemas.v2_device import DeviceRegisterRequestV2 +from app.security.tokens_v2 import ( + TokenErrorV2, + build_device_token_pair, + create_bootstrap_token, + create_refresh_token, + create_access_token, + decode_token, + hash_refresh_token, +) +from app.services.auth_service import AuthServiceError, auth_service +from app.services.device_service_v2 import device_service_v2 + + +def canonical_bind_device_string( + user_id: str, + device_uid: str, + nonce: str, + timestamp_ms: int, +) -> bytes: + return "\n".join( + [ + "BIND_DEVICE", + user_id.strip(), + device_uid.strip(), + nonce.strip(), + str(timestamp_ms), + ] + ).encode("utf-8") + + +class AuthServiceV2: + def __init__(self) -> None: + self.settings = get_settings() + + async def register(self, session: AsyncSession, username: str, password: str) -> User: + return await auth_service.register(session, username, password) + + async def authenticate(self, session: AsyncSession, username: str, password: str) -> User: + return await auth_service.authenticate(session, username, password) + + async def issue_bootstrap(self, user: User) -> BootstrapTokenBundleV2: + token, ttl = create_bootstrap_token(user.id, user.username) + return BootstrapTokenBundleV2( + bootstrap_token=token, + token_type="bootstrap", + bootstrap_expires_in=ttl, + ) + + async def issue_device_tokens( + self, + session: AsyncSession, + user: User, + device_uid: str, + ) -> DeviceTokenBundleV2: + token_pair = build_device_token_pair(user_id=user.id, username=user.username, device_uid=device_uid) + refresh_record = RefreshToken( + id=token_pair.refresh_token_id, + user_id=user.id, + device_id=device_uid, + token_hash=hash_refresh_token(token_pair.refresh_token), + expires_at=token_pair.refresh_expires_at, + token_kind="refresh", + ) + session.add(refresh_record) + await session.commit() + return DeviceTokenBundleV2( + access_token=token_pair.access_token, + access_expires_in=token_pair.access_expires_in, + refresh_token=token_pair.refresh_token, + refresh_expires_in=token_pair.refresh_expires_in, + device_uid=device_uid, + ) + + async def register_device_from_bootstrap( + self, + session: AsyncSession, + user: User, + label: str, + pub_sign_key: str, + pub_dh_key: str, + ) -> tuple[str, DeviceTokenBundleV2]: + device = await device_service_v2.register_device( + session, + user, + payload=DeviceRegisterRequestV2( + label=label, + pub_sign_key=pub_sign_key, + pub_dh_key=pub_dh_key, + ), + ) + tokens = await self.issue_device_tokens(session, user, device.device_uid) + return device.device_uid, tokens + + async def bind_device( + self, + session: AsyncSession, + user: User, + payload: BindDeviceRequestV2, + ) -> DeviceTokenBundleV2: + now_ms = int(time.time() * 1000) + if abs(now_ms - payload.timestamp_ms) > (self.settings.v2_bind_request_skew_seconds * 1000): + raise AuthServiceError("Bind request timestamp skew exceeded", status_code=401) + + device = await device_service_v2.get_device_for_user(session, user.id, payload.device_uid) + if device is None: + raise AuthServiceError("Device not found", status_code=404) + if device.status != "active" or device.revoked_at is not None: + raise AuthServiceError("Device revoked", status_code=401) + + canonical = canonical_bind_device_string(user.id, payload.device_uid, payload.nonce, payload.timestamp_ms) + try: + verify_key = signing.VerifyKey(device.ik_ed25519_pub.encode("utf-8"), encoder=encoding.Base64Encoder) + signature = encoding.Base64Encoder.decode(payload.proof_signature_b64.encode("utf-8")) + verify_key.verify(canonical, signature) + except (ValueError, nacl_exceptions.BadSignatureError) as exc: + raise AuthServiceError("Device proof signature invalid", status_code=401) from exc + + return await self.issue_device_tokens(session, user, device.id) + + async def refresh(self, session: AsyncSession, refresh_token_raw: str) -> tuple[User, DeviceTokenBundleV2]: + try: + payload = decode_token(refresh_token_raw, expected_type="refresh") + except TokenErrorV2 as exc: + raise AuthServiceError(str(exc), status_code=401) from exc + + user_id = str(payload.get("sub") or "") + refresh_id = str(payload.get("rid") or "") + device_uid = str(payload.get("did") or "") + if not user_id or not refresh_id or not device_uid: + raise AuthServiceError("Invalid refresh token payload", status_code=401) + + stmt = select(RefreshToken).where( + RefreshToken.id == refresh_id, + RefreshToken.user_id == user_id, + RefreshToken.device_id == device_uid, + RefreshToken.token_kind == "refresh", + ) + current = (await session.execute(stmt)).scalar_one_or_none() + if current is None: + raise AuthServiceError("Refresh token not found", status_code=401) + + now = datetime.now(UTC) + if current.revoked_at is not None: + raise AuthServiceError("Refresh token revoked", status_code=401) + if current.expires_at.astimezone(UTC) <= now: + current.revoked_at = now + await session.commit() + raise AuthServiceError("Refresh token expired", status_code=401) + if current.token_hash != hash_refresh_token(refresh_token_raw): + current.revoked_at = now + await session.commit() + raise AuthServiceError("Refresh token mismatch", status_code=401) + + user_stmt = select(User).where(User.id == user_id, User.disabled_at.is_(None)) + user = (await session.execute(user_stmt)).scalar_one_or_none() + if user is None: + raise AuthServiceError("User not found", status_code=401) + + access_token, access_ttl = create_access_token(user.id, user.username, device_uid) + new_refresh_token, new_refresh_id, new_expiry, new_refresh_ttl = create_refresh_token(user.id, device_uid) + + replacement = RefreshToken( + id=new_refresh_id, + user_id=user.id, + device_id=device_uid, + token_hash=hash_refresh_token(new_refresh_token), + expires_at=new_expiry, + token_kind="refresh", + ) + session.add(replacement) + await session.flush() + + current.revoked_at = now + current.replaced_by = replacement.id + await session.commit() + + bundle = DeviceTokenBundleV2( + access_token=access_token, + access_expires_in=access_ttl, + refresh_token=new_refresh_token, + refresh_expires_in=new_refresh_ttl, + device_uid=device_uid, + ) + return user, bundle + + async def logout(self, session: AsyncSession, refresh_token_raw: str) -> None: + try: + payload = decode_token(refresh_token_raw, expected_type="refresh") + except TokenErrorV2: + return + + refresh_id = payload.get("rid") + user_id = payload.get("sub") + if not refresh_id or not user_id: + return + + stmt = select(RefreshToken).where(RefreshToken.id == refresh_id, RefreshToken.user_id == user_id) + token = (await session.execute(stmt)).scalar_one_or_none() + if token is None: + return + + hashed = hash_refresh_token(refresh_token_raw) + if token.token_hash != hashed: + return + + token.revoked_at = datetime.now(UTC) + await session.commit() + + +auth_service_v2 = AuthServiceV2() diff --git a/server/app/services/call_service.py b/server/app/services/call_service.py index add4ab5..885c0d7 100644 --- a/server/app/services/call_service.py +++ b/server/app/services/call_service.py @@ -2,6 +2,7 @@ import asyncio import base64 import binascii +import json import time from dataclasses import dataclass, field from typing import Literal @@ -19,6 +20,9 @@ CallEndRequest, CallOfferRequest, CallRejectRequest, + CallWebRtcAnswerRequest, + CallWebRtcIceRequest, + CallWebRtcOfferRequest, ) from app.schemas.federation import ( FederationCallAcceptRequest, @@ -27,6 +31,11 @@ FederationCallOfferRequest, FederationCallRejectRequest, ) +from app.schemas.v2_federation import ( + FederationCallWebRtcAnswerRequestV2, + FederationCallWebRtcIceRequestV2, + FederationCallWebRtcOfferRequestV2, +) from app.services.conversation_service import conversation_service from app.services.federation_client import FederationClientError, federation_client from app.services.peer_address import parse_peer_address_with_policy @@ -86,6 +95,24 @@ def __init__(self) -> None: self._user_to_call: dict[str, str] = {} self._lock = asyncio.Lock() + def _webrtc_metadata(self) -> dict: + if not self.settings.enable_webrtc_v2b2: + return {} + metadata: dict = { + "call_schema_version": 1, + "call_mode": "webrtc", + "max_participants": 2, + } + try: + parsed = json.loads(self.settings.webrtc_ice_servers_json) if self.settings.webrtc_ice_servers_json else [] + except json.JSONDecodeError: + parsed = [] + if isinstance(parsed, list): + metadata["ice_servers"] = parsed + else: + metadata["ice_servers"] = [] + return metadata + async def _local_address_for_user_id(self, session: AsyncSession, user_id: str) -> str: stmt = select(User).where(User.id == user_id) user = (await session.execute(stmt)).scalar_one_or_none() @@ -131,7 +158,6 @@ async def offer( incoming_payload = None ringing_payload = None busy_payload = None - rejected_payload = None async with self._lock: if caller_user_id in self._user_to_call: @@ -146,12 +172,6 @@ async def offer( "reason": "peer_busy", "conversation_id": payload.conversation_id, } - elif not await connection_manager.has_user(callee_user_id): - rejected_payload = { - "type": "call.rejected", - "reason": "peer_offline", - "conversation_id": payload.conversation_id, - } else: call_id = str(uuid4()) call = CallSession( @@ -186,9 +206,6 @@ async def offer( if busy_payload is not None: await connection_manager.send_to_user(caller_user_id, busy_payload) return - if rejected_payload is not None: - await connection_manager.send_to_user(caller_user_id, rejected_payload) - return if incoming_payload is not None: await connection_manager.send_to_user(callee_user_id, incoming_payload) if ringing_payload is not None: @@ -297,6 +314,7 @@ async def accept(self, user_id: str, payload: CallAcceptRequest) -> None: "conversation_id": call.conversation_id, "peer_user_id": call.callee_user_id or "", "peer_user_address": call.callee_address, + **self._webrtc_metadata(), }, ), ( @@ -307,6 +325,7 @@ async def accept(self, user_id: str, payload: CallAcceptRequest) -> None: "conversation_id": call.conversation_id, "peer_user_id": call.caller_user_id or "", "peer_user_address": call.caller_address, + **self._webrtc_metadata(), }, ), ] @@ -330,6 +349,7 @@ async def accept(self, user_id: str, payload: CallAcceptRequest) -> None: "conversation_id": call.conversation_id, "peer_user_id": "", "peer_user_address": call.caller_address, + **self._webrtc_metadata(), }, ) ] @@ -492,6 +512,8 @@ async def end(self, user_id: str, payload: CallEndRequest) -> None: pass async def audio(self, user_id: str, payload: CallAudioRequest) -> None: + if not self.settings.enable_legacy_call_audio_ws: + raise CallProtocolError("audio_deprecated", "WS audio transport is disabled; use WebRTC") try: pcm = base64.b64decode(payload.pcm_b64.encode("utf-8"), validate=True) except binascii.Error as exc: @@ -562,10 +584,215 @@ async def audio(self, user_id: str, payload: CallAudioRequest) -> None: ) except FederationClientError as exc: raise CallProtocolError("federation_audio_failed", exc.detail) from exc + + def _validate_webrtc_enabled(self) -> None: + if not self.settings.enable_webrtc_v2b2: + raise CallProtocolError("webrtc_disabled", "WebRTC signaling is disabled") + + @staticmethod + def _validate_call_mode(call_mode: str, max_participants: int) -> None: + normalized_mode = (call_mode or "").strip().lower() + if normalized_mode != "webrtc": + raise CallProtocolError("invalid_call_mode", "call_mode must be 'webrtc'") + if max_participants != 2: + raise CallProtocolError("invalid_participant_limit", "Only 1:1 calls are supported in this release") + + async def webrtc_offer(self, user_id: str, payload: CallWebRtcOfferRequest) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_signaling( + user_id=user_id, + call_id=payload.call_id, + event_type="call.webrtc.offer", + local_payload={ + "type": "call.webrtc.offer", + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + federation_path="/api/v2/federation/calls/webrtc-offer", + federation_payload={ + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + async def webrtc_answer(self, user_id: str, payload: CallWebRtcAnswerRequest) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_signaling( + user_id=user_id, + call_id=payload.call_id, + event_type="call.webrtc.answer", + local_payload={ + "type": "call.webrtc.answer", + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + federation_path="/api/v2/federation/calls/webrtc-answer", + federation_payload={ + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + async def webrtc_ice(self, user_id: str, payload: CallWebRtcIceRequest) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_signaling( + user_id=user_id, + call_id=payload.call_id, + event_type="call.webrtc.ice", + local_payload={ + "type": "call.webrtc.ice", + "call_id": payload.call_id, + "candidate": payload.candidate, + "sdp_mid": payload.sdp_mid, + "sdp_mline_index": payload.sdp_mline_index, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + federation_path="/api/v2/federation/calls/webrtc-ice", + federation_payload={ + "call_id": payload.call_id, + "candidate": payload.candidate, + "sdp_mid": payload.sdp_mid, + "sdp_mline_index": payload.sdp_mline_index, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + async def _relay_webrtc_signaling( + self, + *, + user_id: str, + call_id: str, + event_type: str, + local_payload: dict, + federation_path: str, + federation_payload: dict, + ) -> None: + local_target: str | None = None + relay_data: tuple[str, str, str] | None = None + + async with self._lock: + call = self._calls.get(call_id) + if call is None: + raise CallProtocolError("call_not_found", "Call session not found") + if call.state != "active": + raise CallProtocolError("call_not_active", "Call is not active") + if not call.has_participant(user_id): + raise CallProtocolError("forbidden", "Not a participant in this call") + + if call.direction == "local": + local_target = call.peer_of(user_id) + elif call.direction == "federated_outbound": + parsed_peer = parse_peer_address_with_policy(call.callee_address, self.settings.tor_enabled) + relay_data = (parsed_peer.server_onion, call.caller_address, call.callee_address) + else: + parsed_peer = parse_peer_address_with_policy(call.caller_address, self.settings.tor_enabled) + relay_data = (parsed_peer.server_onion, call.callee_address, call.caller_address) + + if local_target is not None: + await connection_manager.send_to_user(local_target, local_payload) + return + + if relay_data is not None: + peer_onion, from_user_address, to_user_address = relay_data + payload_json = { + "relay_id": str(uuid4()), + "from_user_address": from_user_address, + "to_user_address": to_user_address, + } + payload_json.update(federation_payload) + try: + await federation_client.post_signed(peer_onion, federation_path, payload_json) + except FederationClientError as exc: + raise CallProtocolError(f"{event_type}.relay_failed", exc.detail) from exc + + async def relay_webrtc_offer(self, payload: FederationCallWebRtcOfferRequestV2) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_event_to_local( + call_id=payload.call_id, + payload={ + "type": "call.webrtc.offer", + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + "from_user_address": payload.from_user_address, + }, + ) + + async def relay_webrtc_answer(self, payload: FederationCallWebRtcAnswerRequestV2) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_event_to_local( + call_id=payload.call_id, + payload={ + "type": "call.webrtc.answer", + "call_id": payload.call_id, + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + "from_user_address": payload.from_user_address, + }, + ) + + async def relay_webrtc_ice(self, payload: FederationCallWebRtcIceRequestV2) -> None: + self._validate_webrtc_enabled() + self._validate_call_mode(payload.call_mode, payload.max_participants) + await self._relay_webrtc_event_to_local( + call_id=payload.call_id, + payload={ + "type": "call.webrtc.ice", + "call_id": payload.call_id, + "candidate": payload.candidate, + "sdp_mid": payload.sdp_mid, + "sdp_mline_index": payload.sdp_mline_index, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + "from_user_address": payload.from_user_address, + }, + ) + + async def _relay_webrtc_event_to_local(self, call_id: str, payload: dict) -> None: + local_target: str | None = None + async with self._lock: + call = self._calls.get(call_id) + if call is None: + raise CallProtocolError("call_not_found", "Call session not found") + if call.state != "active": + raise CallProtocolError("call_not_active", "Call is not active") + if call.direction == "federated_outbound": + local_target = call.caller_user_id + elif call.direction == "federated_inbound": + local_target = call.callee_user_id + else: + raise CallProtocolError("invalid_state", "Local call cannot consume federated WebRTC relay") + if local_target: + await connection_manager.send_to_user(local_target, payload) + async def relay_offer(self, session: AsyncSession, payload: FederationCallOfferRequest) -> None: local_user = await self._local_user_for_address(session, payload.to_user_address) - if not await connection_manager.has_user(local_user.id): - raise CallProtocolError("peer_offline", "Local callee is offline") async with self._lock: if local_user.id in self._user_to_call: @@ -637,6 +864,7 @@ async def relay_accept(self, payload: FederationCallAcceptRequest) -> None: "conversation_id": conversation_id, "peer_user_id": "", "peer_user_address": peer_address, + **self._webrtc_metadata(), }, ) @@ -687,6 +915,8 @@ async def relay_end(self, payload: FederationCallEndRequest) -> None: ) async def relay_audio(self, payload: FederationCallAudioRequest) -> None: + if not self.settings.enable_legacy_call_audio_ws: + raise CallProtocolError("audio_deprecated", "WS audio transport is disabled; use WebRTC") local_target: str | None = None try: pcm = base64.b64decode(payload.pcm_b64.encode("utf-8"), validate=True) diff --git a/server/app/services/conversation_service.py b/server/app/services/conversation_service.py index f3e7233..7c0f120 100644 --- a/server/app/services/conversation_service.py +++ b/server/app/services/conversation_service.py @@ -4,8 +4,10 @@ from app.config import get_settings from app.models.conversation import Conversation +from app.models.conversation_member import ConversationMember from app.models.user import User from app.services.peer_address import parse_peer_address, parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority from app.services.server_identity import get_server_onion, server_address_for_username @@ -44,6 +46,7 @@ async def _create_local_dm(self, session: AsyncSession, user: User, peer_usernam conversation = Conversation( kind="local", + conversation_type="direct", user_a_id=user_a, user_b_id=user_b, peer_server_onion=get_server_onion(), @@ -58,13 +61,14 @@ async def _create_remote_dm( session: AsyncSession, user: User, peer_address: str, + request_authority: str | None = None, ) -> Conversation: try: parsed = parse_peer_address_with_policy(peer_address, self.settings.tor_enabled) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc - local_onion = get_server_onion() - if parsed.server_onion == local_onion: + additional_aliases = {request_authority} if request_authority else None + if is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): return await self._create_local_dm(session, user, parsed.username) existing_stmt = select(Conversation).where( @@ -78,6 +82,7 @@ async def _create_remote_dm( conversation = Conversation( kind="remote", + conversation_type="direct", local_user_id=user.id, peer_username=parsed.username, peer_server_onion=parsed.server_onion, @@ -94,9 +99,10 @@ async def create_dm( user: User, peer_username: str | None, peer_address: str | None, + request_authority: str | None = None, ) -> Conversation: if peer_address is not None and peer_address.strip(): - return await self._create_remote_dm(session, user, peer_address) + return await self._create_remote_dm(session, user, peer_address, request_authority=request_authority) if peer_username is not None and peer_username.strip(): return await self._create_local_dm(session, user, peer_username) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="peer target is required") @@ -117,6 +123,15 @@ async def list_for_user( or_(Conversation.user_a_id == user.id, Conversation.user_b_id == user.id), ), and_(Conversation.kind == "remote", Conversation.local_user_id == user.id), + and_( + Conversation.conversation_type == "group", + Conversation.id.in_( + select(ConversationMember.conversation_id).where( + ConversationMember.member_user_id == user.id, + ConversationMember.status.in_(["active", "invited"]), + ) + ), + ), ) ) .order_by(Conversation.created_at.desc()) @@ -133,6 +148,8 @@ async def peer_username_for_user( ) -> str: if conversation.kind == "remote": return conversation.peer_username or "" + if conversation.conversation_type == "group": + return conversation.group_name or "" peer_user_id = self.peer_id(conversation, user_id) stmt = select(User.username).where(User.id == peer_user_id) @@ -147,6 +164,8 @@ async def peer_address_for_user( ) -> str: if conversation.kind == "remote": return conversation.peer_address or "" + if conversation.conversation_type == "group": + return conversation.owner_address or "" username = await self.peer_username_for_user(session, conversation, user_id) if not username: return "" @@ -160,6 +179,8 @@ async def peer_server_onion_for_user( ) -> str: if conversation.kind == "remote": return conversation.peer_server_onion or "" + if conversation.conversation_type == "group": + return conversation.origin_server_onion or "" peer_address = await self.peer_address_for_user(session, conversation, user_id) try: @@ -168,6 +189,11 @@ async def peer_server_onion_for_user( return "" def ensure_membership(self, conversation: Conversation, user_id: str) -> None: + if conversation.conversation_type == "group": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Group membership must be checked via group service", + ) if conversation.kind == "remote": if conversation.local_user_id != user_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a conversation member") @@ -176,6 +202,8 @@ def ensure_membership(self, conversation: Conversation, user_id: str) -> None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a conversation member") def peer_id(self, conversation: Conversation, user_id: str) -> str: + if conversation.conversation_type == "group": + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Group conversation has multiple peers") if conversation.kind == "remote": raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Remote conversation has no local peer") if conversation.user_a_id == user_id: @@ -215,6 +243,7 @@ async def get_or_create_remote_for_local_user( conversation = Conversation( kind="remote", + conversation_type="direct", local_user_id=local_user.id, peer_username=parsed.username, peer_server_onion=parsed.server_onion, diff --git a/server/app/services/device_service.py b/server/app/services/device_service.py index 7575a11..e4663db 100644 --- a/server/app/services/device_service.py +++ b/server/app/services/device_service.py @@ -11,7 +11,7 @@ from app.schemas.device import UserDeviceLookup from app.services.federation_client import FederationClientError, federation_client from app.services.peer_address import parse_peer_address_with_policy -from app.services.server_identity import get_server_onion +from app.services.server_authority import is_local_server_authority class DeviceService: @@ -34,6 +34,7 @@ async def register_device( label=payload.label, ik_ed25519_pub=payload.ik_ed25519_pub, enc_x25519_pub=payload.enc_x25519_pub, + status="active", ) session.add(device) await session.flush() @@ -45,6 +46,7 @@ async def register_device( old_device = (await session.execute(old_device_stmt)).scalar_one_or_none() if old_device is not None: old_device.revoked_at = datetime.now(UTC) + old_device.status = "revoked" current_active.device_id = device.id current_active.updated_at = datetime.now(UTC) else: @@ -58,7 +60,11 @@ async def get_active_device_for_user(self, session: AsyncSession, user_id: str) stmt = ( select(Device) .join(ActiveDevice, ActiveDevice.device_id == Device.id) - .where(ActiveDevice.user_id == user_id, Device.revoked_at.is_(None)) + .where( + ActiveDevice.user_id == user_id, + Device.revoked_at.is_(None), + Device.status == "active", + ) ) return (await session.execute(stmt)).scalar_one_or_none() @@ -78,12 +84,14 @@ async def resolve_device_by_peer_address( self, session: AsyncSession, peer_address: str, + request_authority: str | None = None, ) -> UserDeviceLookup | None: try: parsed = parse_peer_address_with_policy(peer_address, self.settings.tor_enabled) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc - if parsed.server_onion == get_server_onion(): + additional_aliases = {request_authority} if request_authority else None + if is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): local_result = await self.get_active_device_by_username(session, parsed.username) if local_result is None: return None diff --git a/server/app/services/device_service_v2.py b/server/app/services/device_service_v2.py new file mode 100644 index 0000000..6a21ff0 --- /dev/null +++ b/server/app/services/device_service_v2.py @@ -0,0 +1,175 @@ +from datetime import UTC, datetime + +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.device import Device +from app.models.refresh_token import RefreshToken +from app.models.user import User +from app.schemas.v2_device import DeviceOutV2, DeviceRegisterRequestV2, UserDeviceLookupV2 +from app.services.federation_client import FederationClientError, federation_client +from app.services.metrics import metrics +from app.services.peer_address import parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority +from app.ws.manager import connection_manager + + +class DeviceServiceV2: + def __init__(self) -> None: + self.settings = get_settings() + + def supported_message_modes(self) -> list[str]: + modes = ["sealedbox_v0_2a"] + if self.settings.enable_ratchet_v2b1: + modes.append("ratchet_v0_2b1") + return modes + + def _to_device_out(self, device: Device) -> DeviceOutV2: + return DeviceOutV2( + device_uid=device.id, + user_id=device.user_id, + label=device.label, + pub_sign_key=device.ik_ed25519_pub, + pub_dh_key=device.enc_x25519_pub, + status=device.status, + supported_message_modes=self.supported_message_modes(), + created_at=device.created_at, + last_seen_at=device.last_seen_at, + revoked_at=device.revoked_at, + ) + + def _local_attachment_inline_max_bytes(self) -> int: + return self.settings.effective_attachment_inline_max_bytes() + + def _local_max_ciphertext_bytes(self) -> int: + return self.settings.effective_max_ciphertext_bytes() + + async def register_device( + self, + session: AsyncSession, + user: User, + payload: DeviceRegisterRequestV2, + ) -> DeviceOutV2: + existing_key_stmt = select(Device).where(Device.enc_x25519_pub == payload.pub_dh_key) + existing_key = (await session.execute(existing_key_stmt)).scalar_one_or_none() + if existing_key is not None: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Device key already registered") + + device = Device( + user_id=user.id, + label=payload.label, + ik_ed25519_pub=payload.pub_sign_key, + enc_x25519_pub=payload.pub_dh_key, + status="active", + last_seen_at=datetime.now(UTC), + revoked_at=None, + ) + session.add(device) + await session.commit() + await session.refresh(device) + return self._to_device_out(device) + + async def list_for_user(self, session: AsyncSession, user_id: str) -> list[DeviceOutV2]: + stmt = select(Device).where(Device.user_id == user_id).order_by(Device.created_at.asc()) + rows = list((await session.execute(stmt)).scalars().all()) + return [self._to_device_out(row) for row in rows] + + async def list_active_for_user(self, session: AsyncSession, user_id: str) -> list[Device]: + stmt = ( + select(Device) + .where( + Device.user_id == user_id, + Device.status == "active", + Device.revoked_at.is_(None), + ) + .order_by(Device.created_at.asc()) + ) + return list((await session.execute(stmt)).scalars().all()) + + async def get_device_for_user(self, session: AsyncSession, user_id: str, device_uid: str) -> Device | None: + stmt = select(Device).where(Device.id == device_uid, Device.user_id == user_id) + return (await session.execute(stmt)).scalar_one_or_none() + + async def revoke_device(self, session: AsyncSession, user: User, device_uid: str) -> DeviceOutV2: + device = await self.get_device_for_user(session, user.id, device_uid) + if device is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found") + if device.status == "revoked": + return self._to_device_out(device) + + now = datetime.now(UTC) + device.status = "revoked" + device.revoked_at = now + device.last_seen_at = now + + refresh_tokens_stmt = select(RefreshToken).where( + RefreshToken.user_id == user.id, + RefreshToken.device_id == device.id, + RefreshToken.revoked_at.is_(None), + RefreshToken.token_kind == "refresh", + ) + refresh_tokens = list((await session.execute(refresh_tokens_stmt)).scalars().all()) + for refresh in refresh_tokens: + refresh.revoked_at = now + + await session.commit() + await connection_manager.disconnect_device(device.id) + return self._to_device_out(device) + + async def resolve_devices_by_peer_address( + self, + session: AsyncSession, + peer_address: str, + request_authority: str | None = None, + ) -> UserDeviceLookupV2 | None: + try: + parsed = parse_peer_address_with_policy(peer_address, self.settings.tor_enabled) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + + additional_aliases = {request_authority} if request_authority else None + if is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): + user_stmt = select(User).where(User.username == parsed.username, User.disabled_at.is_(None)) + user = (await session.execute(user_stmt)).scalar_one_or_none() + if user is None: + return None + devices = await self.list_active_for_user(session, user.id) + return UserDeviceLookupV2( + username=user.username, + peer_address=parsed.canonical, + devices=[self._to_device_out(device) for device in devices], + attachment_inline_max_bytes=self._local_attachment_inline_max_bytes(), + max_ciphertext_bytes=self._local_max_ciphertext_bytes(), + attachment_policy_source="local", + ) + + try: + remote = await federation_client.get_remote_user_devices_v2(parsed.server_onion, parsed.username) + except FederationClientError as exc: + if exc.status_code == 404: + return None + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=exc.detail) from exc + + if not remote.peer_address: + remote.peer_address = parsed.canonical + remote.attachment_policy_source = "remote" + if remote.attachment_inline_max_bytes <= 0 or remote.max_ciphertext_bytes <= 0: + remote.attachment_inline_max_bytes = self._local_attachment_inline_max_bytes() + remote.max_ciphertext_bytes = self._local_max_ciphertext_bytes() + remote.attachment_policy_source = "fallback_local" + await metrics.inc("attachments.policy.fallback_local") + else: + remote.attachment_inline_max_bytes = min( + remote.attachment_inline_max_bytes, + self.settings.attachment_hard_ceiling_bytes, + ) + remote.max_ciphertext_bytes = min( + remote.max_ciphertext_bytes, + self.settings.attachment_hard_ceiling_bytes, + ) + return remote + + +device_service_v2 = DeviceServiceV2() diff --git a/server/app/services/federation_client.py b/server/app/services/federation_client.py index a21b390..5f03e13 100644 --- a/server/app/services/federation_client.py +++ b/server/app/services/federation_client.py @@ -2,9 +2,13 @@ import httpx import orjson +from pydantic import ValidationError from app.config import get_settings from app.schemas.device import UserDeviceLookup +from app.schemas.v2_device import UserDeviceLookupV2 +from app.schemas.v2_federation import FederationGroupSnapshotOutV2 +from app.schemas.v2_prekey import ResolvePrekeysResponseV2 from app.services.federation_security import federation_security_service from app.services.peer_address import is_onion_authority @@ -49,6 +53,87 @@ async def get_remote_user_device(self, peer_onion: str, username: str) -> UserDe raise FederationClientError(response.status_code, response.text or "Remote device lookup failed") return UserDeviceLookup.model_validate(response.json()) + async def get_remote_user_devices_v2(self, peer_onion: str, username: str) -> UserDeviceLookupV2: + peer = peer_onion.strip().lower() + normalized_username = username.strip().lower() + if not is_onion_authority(peer): + raise FederationClientError(400, "Invalid remote server onion authority") + if not normalized_username: + raise FederationClientError(400, "Invalid remote username") + + async with await self._build_http_client() as client: + try: + response = await client.get( + "http://" + f"{peer}/api/v2/federation/users/{quote(normalized_username, safe='')}/devices" + ) + except httpx.HTTPError as exc: + raise FederationClientError( + 502, + f"Remote device lookup failed for {normalized_username}@{peer}", + ) from exc + + if response.status_code >= 400: + raise FederationClientError(response.status_code, response.text or "Remote device lookup failed") + payload = response.json() + try: + return UserDeviceLookupV2.model_validate(payload) + except ValidationError: + fallback_payload = { + "username": payload.get("username", normalized_username), + "peer_address": payload.get("peer_address", ""), + "devices": payload.get("devices", []), + "attachment_inline_max_bytes": 0, + "max_ciphertext_bytes": 0, + "attachment_policy_source": "fallback_local", + } + return UserDeviceLookupV2.model_validate(fallback_payload) + + async def get_remote_user_prekeys_v2(self, peer_onion: str, username: str) -> ResolvePrekeysResponseV2: + peer = peer_onion.strip().lower() + normalized_username = username.strip().lower() + if not is_onion_authority(peer): + raise FederationClientError(400, "Invalid remote server onion authority") + if not normalized_username: + raise FederationClientError(400, "Invalid remote username") + + async with await self._build_http_client() as client: + try: + response = await client.get( + "http://" + f"{peer}/api/v2/federation/users/{quote(normalized_username, safe='')}/prekeys" + ) + except httpx.HTTPError as exc: + raise FederationClientError( + 502, + f"Remote prekey lookup failed for {normalized_username}@{peer}", + ) from exc + + if response.status_code >= 400: + raise FederationClientError(response.status_code, response.text or "Remote prekey lookup failed") + return ResolvePrekeysResponseV2.model_validate(response.json()) + + async def get_remote_group_snapshot_v2(self, peer_onion: str, group_uid: str) -> FederationGroupSnapshotOutV2: + peer = peer_onion.strip().lower() + normalized_group_uid = group_uid.strip().lower() + if not is_onion_authority(peer): + raise FederationClientError(400, "Invalid remote server onion authority") + if not normalized_group_uid: + raise FederationClientError(400, "Invalid group uid") + + async with await self._build_http_client() as client: + try: + response = await client.get( + "http://" + f"{peer}/api/v2/federation/groups/{quote(normalized_group_uid, safe='')}/snapshot" + ) + except httpx.HTTPError as exc: + raise FederationClientError(502, f"Remote group snapshot lookup failed for {normalized_group_uid}") from exc + + if response.status_code >= 400: + raise FederationClientError(response.status_code, response.text or "Remote group snapshot failed") + return FederationGroupSnapshotOutV2.model_validate(response.json()) + async def post_signed( self, peer_onion: str, diff --git a/server/app/services/federation_security.py b/server/app/services/federation_security.py index c6dcbfe..3a525b9 100644 --- a/server/app/services/federation_security.py +++ b/server/app/services/federation_security.py @@ -26,6 +26,21 @@ _NONCE_PATTERN = re.compile(r"^[a-zA-Z0-9._:-]{8,128}$") +def onion_v3_public_key_from_authority(onion_authority: str) -> bytes | None: + normalized = onion_authority.strip().lower() + if normalized.endswith(".onion"): + normalized = normalized[:-6] + if len(normalized) != 56: + return None + try: + decoded = base64.b32decode(normalized.upper()) + except Exception: + return None + if len(decoded) != 35: + return None + return decoded[:32] + + def canonical_request_string( method: str, path: str, @@ -78,6 +93,20 @@ async def fetch_well_known(self, peer_onion: str) -> FederationWellKnownOut: status_code=status.HTTP_502_BAD_GATEWAY, detail="Peer well-known server identity mismatch", ) + onion_pub = onion_v3_public_key_from_authority(normalized) + if onion_pub is not None: + try: + discovered_key = base64.b64decode(payload.signing_public_key.encode("utf-8"), validate=True) + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Peer well-known signing key is invalid base64", + ) from exc + if discovered_key != onion_pub: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Peer signing key does not match onion identity", + ) return payload async def get_or_onboard_peer(self, session: AsyncSession, peer_onion: str) -> FederationPeer: @@ -99,6 +128,12 @@ async def get_or_onboard_peer(self, session: AsyncSession, peer_onion: str) -> F if peer.status == "key_conflict": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Peer key conflict") + discovered = await self.fetch_well_known(normalized) + if discovered.signing_public_key != peer.signing_public_key: + peer.status = "key_conflict" + await session.commit() + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Peer key conflict") + peer.last_seen_at = datetime.now(UTC) await session.flush() return peer diff --git a/server/app/services/group_call_service.py b/server/app/services/group_call_service.py new file mode 100644 index 0000000..deccb21 --- /dev/null +++ b/server/app/services/group_call_service.py @@ -0,0 +1,863 @@ +from __future__ import annotations + +import asyncio +import base64 +import binascii +import time +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.conversation import Conversation +from app.models.conversation_member import ConversationMember +from app.models.group_call_participant import GroupCallParticipant +from app.models.group_call_session import GroupCallSession +from app.models.user import User +from app.schemas.call import CallAudioRequest +from app.schemas.v2_federation import ( + FederationGroupCallEndRequestV2, + FederationGroupCallJoinRequestV2, + FederationGroupCallLeaveRequestV2, + FederationGroupCallOfferRequestV2, + FederationGroupCallWebRtcAnswerRequestV2, + FederationGroupCallWebRtcIceRequestV2, + FederationGroupCallWebRtcOfferRequestV2, +) +from app.services.federation_client import FederationClientError, federation_client +from app.services.metrics import metrics +from app.services.peer_address import parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority +from app.services.server_identity import get_server_onion, server_address_for_username +from app.ws.manager import connection_manager + + +class GroupCallService: + def __init__(self) -> None: + self.settings = get_settings() + self._audio_last_at: dict[tuple[str, str], float] = {} + + @staticmethod + def _delivery_states() -> set[str]: + return {"joined", "ringing", "offline_pending"} + + def _clear_audio_state_for_call(self, call_id: str) -> None: + stale_keys = [key for key in self._audio_last_at if key[0] == call_id] + for key in stale_keys: + self._audio_last_at.pop(key, None) + + async def _require_enabled(self) -> None: + if not self.settings.enable_group_call_v2c: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group calls are disabled") + + async def _user(self, session: AsyncSession, user_id: str) -> User: + stmt = select(User).where(User.id == user_id, User.disabled_at.is_(None)) + user = (await session.execute(stmt)).scalar_one_or_none() + if user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + return user + + async def _call(self, session: AsyncSession, call_id: str) -> GroupCallSession: + stmt = select(GroupCallSession).where(GroupCallSession.call_id == call_id) + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Call not found") + return row + + async def _participants(self, session: AsyncSession, call_id: str) -> list[GroupCallParticipant]: + stmt = select(GroupCallParticipant).where(GroupCallParticipant.call_id == call_id) + return list((await session.execute(stmt)).scalars().all()) + + async def is_group_call(self, session: AsyncSession, call_id: str) -> bool: + stmt = select(GroupCallSession.call_id).where(GroupCallSession.call_id == call_id).limit(1) + return (await session.execute(stmt)).scalar_one_or_none() is not None + + @staticmethod + def _participant_summary(rows: list[GroupCallParticipant]) -> list[dict]: + return [{"member_address": row.member_address, "state": row.state} for row in rows] + + async def _post_federation( + self, + *, + peer_onion: str, + endpoint_path: str, + payload_json: dict, + ) -> None: + try: + await federation_client.post_signed(peer_onion, endpoint_path, payload_json) + except FederationClientError: + await metrics.inc("group_calls.signal.relay_failed") + + def _remote_servers_for_participants(self, rows: list[GroupCallParticipant]) -> set[str]: + remote_servers: set[str] = set() + for row in rows: + try: + parsed = parse_peer_address_with_policy(row.member_address, self.settings.tor_enabled) + except ValueError: + continue + if is_local_server_authority(parsed.server_onion, self.settings): + continue + remote_servers.add(parsed.server_onion) + return remote_servers + + async def _notify_call_state(self, session: AsyncSession, call: GroupCallSession) -> None: + participants = await self._participants(session, call.call_id) + payload = { + "type": "call.group.state", + "call_id": call.call_id, + "conversation_id": call.conversation_id, + "group_uid": call.group_uid, + "state": call.state, + "participants": self._participant_summary(participants), + "call_mode": "webrtc", + "max_participants": self.settings.group_call_max_participants, + } + for row in participants: + if row.local_user_id and row.state in self._delivery_states(): + await connection_manager.send_to_user(row.local_user_id, payload) + + async def _notify_participant_event( + self, + session: AsyncSession, + call: GroupCallSession, + member_address: str, + member_state: str, + ) -> None: + participants = await self._participants(session, call.call_id) + payload = { + "type": "call.group.participant", + "call_id": call.call_id, + "conversation_id": call.conversation_id, + "group_uid": call.group_uid, + "member_address": member_address, + "state": member_state, + } + for row in participants: + if row.local_user_id and row.state in self._delivery_states(): + await connection_manager.send_to_user(row.local_user_id, payload) + + async def offer(self, session: AsyncSession, caller_user_id: str, conversation_id: str) -> GroupCallSession: + await self._require_enabled() + caller = await self._user(session, caller_user_id) + caller_address = server_address_for_username(caller.username) + conversation = (await session.execute(select(Conversation).where(Conversation.id == conversation_id))).scalar_one_or_none() + if conversation is None or conversation.conversation_type != "group": + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group conversation not found") + member_rows = list( + ( + await session.execute( + select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.status == "active", + ) + ) + ).scalars() + ) + caller_member = next((row for row in member_rows if row.member_user_id == caller_user_id), None) + if caller_member is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a group member") + if len(member_rows) > self.settings.group_call_max_participants: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Too many active group members for mesh call") + + existing_call = ( + await session.execute( + select(GroupCallSession) + .where( + GroupCallSession.conversation_id == conversation.id, + GroupCallSession.state.in_(["ringing", "active"]), + ) + .order_by(GroupCallSession.started_at.desc()) + .limit(1) + ) + ).scalar_one_or_none() + if existing_call is not None: + now = datetime.now(UTC) + existing_participants = await self._participants(session, existing_call.call_id) + caller_participant = next( + (row for row in existing_participants if row.member_address == caller_address), + None, + ) + if caller_participant is None: + session.add( + GroupCallParticipant( + call_id=existing_call.call_id, + member_address=caller_address, + local_user_id=caller_user_id, + state="joined", + invited_at=now, + joined_at=now, + last_signal_at=now, + ) + ) + else: + caller_participant.state = "joined" + caller_participant.joined_at = now + caller_participant.left_at = None + caller_participant.last_signal_at = now + existing_call.state = "active" + await session.commit() + await metrics.inc("group_calls.joined") + await self._notify_call_state(session, existing_call) + await self._notify_participant_event(session, existing_call, caller_address, "joined") + participants_after = await self._participants(session, existing_call.call_id) + for peer_onion in self._remote_servers_for_participants(participants_after): + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/join", + payload_json={ + "relay_id": str(uuid4()), + "call_id": existing_call.call_id, + "group_uid": existing_call.group_uid, + "member_address": caller_address, + }, + ) + return existing_call + + remote_servers = { + row.member_server_onion + for row in member_rows + if row.member_server_onion and not is_local_server_authority(row.member_server_onion, self.settings) + } + + now = datetime.now(UTC) + call = GroupCallSession( + call_id=str(uuid4()), + conversation_id=conversation.id, + group_uid=conversation.group_uid or "", + initiator_address=caller_address, + state="ringing", + started_at=now, + ring_expires_at=now + timedelta(seconds=self.settings.group_call_ring_ttl_seconds), + ) + session.add(call) + await session.flush() + + for member in member_rows: + participant_state = "joined" if member.member_address == caller_address else "ringing" + if member.member_user_id and participant_state == "ringing": + if not await connection_manager.has_user(member.member_user_id): + participant_state = "offline_pending" + await metrics.inc("group_calls.offline_pending") + session.add( + GroupCallParticipant( + call_id=call.call_id, + member_address=member.member_address, + local_user_id=member.member_user_id, + state=participant_state, + invited_at=now, + joined_at=now if participant_state == "joined" else None, + ) + ) + + await session.commit() + await metrics.inc("group_calls.created") + await self._notify_call_state(session, call) + + participants = await self._participants(session, call.call_id) + for row in participants: + if row.local_user_id and row.member_address != caller_address: + await connection_manager.send_to_user( + row.local_user_id, + { + "type": "call.group.incoming", + "call_id": call.call_id, + "conversation_id": conversation.id, + "group_uid": call.group_uid, + "from_user_address": caller_address, + "call_mode": "webrtc", + "max_participants": self.settings.group_call_max_participants, + }, + ) + for peer_onion in remote_servers: + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/offer", + payload_json={ + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "conversation_id": call.conversation_id, + "from_user_address": caller_address, + "call_schema_version": 1, + "call_mode": "webrtc", + "max_participants": self.settings.group_call_max_participants, + "ring_expires_at": call.ring_expires_at.isoformat(), + }, + ) + return call + + async def join(self, session: AsyncSession, user_id: str, call_id: str) -> GroupCallSession: + await self._require_enabled() + user = await self._user(session, user_id) + address = server_address_for_username(user.username) + call = await self._call(session, call_id) + row = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.member_address == address, + ) + ) + ).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a call participant") + now = datetime.now(UTC) + row.state = "joined" + row.joined_at = now + row.left_at = None + row.last_signal_at = now + call.state = "active" + await session.commit() + await metrics.inc("group_calls.joined") + await self._notify_call_state(session, call) + await self._notify_participant_event(session, call, address, "joined") + participants = await self._participants(session, call.call_id) + for peer_onion in self._remote_servers_for_participants(participants): + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/join", + payload_json={ + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "member_address": address, + }, + ) + return call + + async def reject(self, session: AsyncSession, user_id: str, call_id: str) -> GroupCallSession: + await self._require_enabled() + user = await self._user(session, user_id) + address = server_address_for_username(user.username) + call = await self._call(session, call_id) + row = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.member_address == address, + ) + ) + ).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a call participant") + row.state = "declined" + row.left_at = datetime.now(UTC) + await session.commit() + if row.local_user_id: + await connection_manager.send_to_user( + row.local_user_id, + {"type": "call.group.ended", "call_id": call.call_id, "reason": "declined"}, + ) + await self._notify_call_state(session, call) + await self._notify_participant_event(session, call, address, "declined") + participants = await self._participants(session, call.call_id) + for peer_onion in self._remote_servers_for_participants(participants): + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/leave", + payload_json={ + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "member_address": address, + "reason": "declined", + }, + ) + return call + + async def leave(self, session: AsyncSession, user_id: str, call_id: str, reason: str = "left") -> GroupCallSession: + await self._require_enabled() + user = await self._user(session, user_id) + address = server_address_for_username(user.username) + call = await self._call(session, call_id) + row = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.member_address == address, + ) + ) + ).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a call participant") + now = datetime.now(UTC) + row.state = "left" + row.left_at = now + row.last_signal_at = now + participants = await self._participants(session, call.call_id) + active_remaining = [ + item for item in participants if item.member_address != address and item.state in {"joined", "ringing", "offline_pending"} + ] + if not active_remaining: + call.state = "ended" + call.ended_at = now + self._clear_audio_state_for_call(call.call_id) + await metrics.inc("group_calls.ended") + await session.commit() + await metrics.inc("group_calls.left") + + if row.local_user_id: + await connection_manager.send_to_user( + row.local_user_id, + {"type": "call.group.ended", "call_id": call.call_id, "reason": reason or "left"}, + ) + + await self._notify_call_state(session, call) + await self._notify_participant_event(session, call, address, "left") + remote_servers = self._remote_servers_for_participants(participants) + for peer_onion in remote_servers: + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/leave", + payload_json={ + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "member_address": address, + "reason": reason, + }, + ) + if call.state == "ended": + for participant in participants: + if participant.local_user_id: + await connection_manager.send_to_user( + participant.local_user_id, + {"type": "call.group.ended", "call_id": call.call_id, "reason": reason}, + ) + for peer_onion in remote_servers: + await self._post_federation( + peer_onion=peer_onion, + endpoint_path="/api/v2/federation/group-calls/end", + payload_json={ + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "reason": reason, + }, + ) + return call + + async def handle_disconnect(self, session: AsyncSession, user_id: str) -> None: + if not self.settings.enable_group_call_v2c: + return + joined_stmt = ( + select(GroupCallParticipant.call_id) + .where( + GroupCallParticipant.local_user_id == user_id, + GroupCallParticipant.state == "joined", + ) + .distinct() + ) + joined_call_ids = [str(item) for item in (await session.execute(joined_stmt)).scalars().all()] + for call_id in joined_call_ids: + try: + await self.leave(session, user_id, call_id, reason="peer_disconnected") + except HTTPException: + continue + + now = datetime.now(UTC) + ringing_rows = list( + ( + await session.execute( + select(GroupCallParticipant, GroupCallSession) + .join(GroupCallSession, GroupCallSession.call_id == GroupCallParticipant.call_id) + .where( + GroupCallParticipant.local_user_id == user_id, + GroupCallParticipant.state.in_(["ringing", "offline_pending"]), + GroupCallSession.state == "ringing", + ) + ) + ).all() + ) + touched_calls: set[str] = set() + for participant, _ in ringing_rows: + participant.state = "offline_pending" + participant.last_signal_at = now + touched_calls.add(participant.call_id) + if ringing_rows: + await session.commit() + for call_id in touched_calls: + try: + call = await self._call(session, call_id) + except HTTPException: + continue + await self._notify_call_state(session, call) + + async def replay_pending_for_user(self, session: AsyncSession, user_id: str) -> None: + await self._require_enabled() + user = await self._user(session, user_id) + address = server_address_for_username(user.username) + now = datetime.now(UTC) + rows = list( + ( + await session.execute( + select(GroupCallParticipant, GroupCallSession) + .join(GroupCallSession, GroupCallSession.call_id == GroupCallParticipant.call_id) + .where( + GroupCallParticipant.member_address == address, + GroupCallParticipant.local_user_id == user_id, + GroupCallParticipant.state == "offline_pending", + GroupCallSession.state == "ringing", + GroupCallSession.ring_expires_at > now, + ) + ) + ).all() + ) + for participant, call in rows: + participant.state = "ringing" + await connection_manager.send_to_user( + user_id, + { + "type": "call.group.incoming", + "call_id": call.call_id, + "conversation_id": call.conversation_id, + "group_uid": call.group_uid, + "from_user_address": call.initiator_address, + "call_mode": "webrtc", + "max_participants": self.settings.group_call_max_participants, + }, + ) + if rows: + await session.commit() + + async def route_webrtc_signal( + self, + session: AsyncSession, + *, + sender_user_id: str, + call_id: str, + event_type: str, + body: dict, + target_user_address: str, + ) -> None: + await self._require_enabled() + call = await self._call(session, call_id) + sender_user = await self._user(session, sender_user_id) + sender_address = server_address_for_username(sender_user.username) + participants = await self._participants(session, call_id) + sender_participant = next((item for item in participants if item.member_address == sender_address), None) + if sender_participant is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a call participant") + sender_participant.last_signal_at = datetime.now(UTC) + target = target_user_address.strip().lower() + target_participant = next((item for item in participants if item.member_address == target), None) + if target_participant is None: + await metrics.inc("group_calls.signal.rejected") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Target is not in this call") + payload = dict(body) + payload["type"] = event_type + payload["call_id"] = call_id + payload["source_user_address"] = sender_address + payload["target_user_address"] = target + if target_participant.local_user_id: + await connection_manager.send_to_user(target_participant.local_user_id, payload) + return + parsed = parse_peer_address_with_policy(target, self.settings.tor_enabled) + if is_local_server_authority(parsed.server_onion, self.settings): + await metrics.inc("group_calls.signal.rejected") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Target user is unavailable") + endpoint_path = "" + federation_payload: dict = { + "relay_id": str(uuid4()), + "call_id": call.call_id, + "group_uid": call.group_uid, + "source_user_address": sender_address, + "target_user_address": target, + "call_schema_version": int(body.get("call_schema_version", 1)), + "call_mode": str(body.get("call_mode", "webrtc")), + "max_participants": int(body.get("max_participants", self.settings.group_call_max_participants)), + } + if event_type == "call.webrtc.offer": + endpoint_path = "/api/v2/federation/group-calls/webrtc-offer" + federation_payload["sdp"] = str(body.get("sdp", "")) + elif event_type == "call.webrtc.answer": + endpoint_path = "/api/v2/federation/group-calls/webrtc-answer" + federation_payload["sdp"] = str(body.get("sdp", "")) + elif event_type == "call.webrtc.ice": + endpoint_path = "/api/v2/federation/group-calls/webrtc-ice" + federation_payload["candidate"] = str(body.get("candidate", "")) + federation_payload["sdp_mid"] = body.get("sdp_mid") + federation_payload["sdp_mline_index"] = body.get("sdp_mline_index") + else: + await metrics.inc("group_calls.signal.rejected") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported signaling event") + try: + await federation_client.post_signed(parsed.server_onion, endpoint_path, federation_payload) + except FederationClientError as exc: + await metrics.inc("group_calls.signal.relay_failed") + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=exc.detail) from exc + + async def relay_offer(self, session: AsyncSession, payload: FederationGroupCallOfferRequestV2) -> None: + await self._require_enabled() + existing = (await session.execute(select(GroupCallSession).where(GroupCallSession.call_id == payload.call_id))).scalar_one_or_none() + if existing is not None: + return + conversation = (await session.execute(select(Conversation).where(Conversation.group_uid == payload.group_uid))).scalar_one_or_none() + if conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found") + parsed = parse_peer_address_with_policy(payload.from_user_address, self.settings.tor_enabled) + if is_local_server_authority(parsed.server_onion, self.settings): + return + call = GroupCallSession( + call_id=payload.call_id, + conversation_id=conversation.id, + group_uid=payload.group_uid, + initiator_address=payload.from_user_address, + state="ringing", + ring_expires_at=datetime.fromisoformat(payload.ring_expires_at.replace("Z", "+00:00")), + started_at=datetime.now(UTC), + ) + session.add(call) + members = list( + ( + await session.execute( + select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.status == "active", + ) + ) + ).scalars() + ) + for member in members: + state_value = "joined" if member.member_address == payload.from_user_address else "ringing" + if member.member_user_id and state_value == "ringing" and not await connection_manager.has_user(member.member_user_id): + state_value = "offline_pending" + session.add( + GroupCallParticipant( + call_id=call.call_id, + member_address=member.member_address, + local_user_id=member.member_user_id, + state=state_value, + invited_at=datetime.now(UTC), + ) + ) + await session.commit() + await self._notify_call_state(session, call) + participants = await self._participants(session, call.call_id) + for row in participants: + if row.local_user_id and row.member_address != payload.from_user_address: + await connection_manager.send_to_user( + row.local_user_id, + { + "type": "call.group.incoming", + "call_id": call.call_id, + "conversation_id": conversation.id, + "group_uid": payload.group_uid, + "from_user_address": payload.from_user_address, + "call_mode": "webrtc", + "max_participants": self.settings.group_call_max_participants, + }, + ) + + async def relay_join(self, session: AsyncSession, payload: FederationGroupCallJoinRequestV2) -> None: + await self._require_enabled() + call = await self._call(session, payload.call_id) + participant = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.member_address == payload.member_address, + ) + ) + ).scalar_one_or_none() + if participant is None: + return + participant.state = "joined" + participant.joined_at = datetime.now(UTC) + call.state = "active" + await session.commit() + await self._notify_call_state(session, call) + await self._notify_participant_event(session, call, payload.member_address, "joined") + + async def relay_leave(self, session: AsyncSession, payload: FederationGroupCallLeaveRequestV2) -> None: + await self._require_enabled() + call = await self._call(session, payload.call_id) + participant = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.member_address == payload.member_address, + ) + ) + ).scalar_one_or_none() + if participant is None: + return + participant.state = "left" + participant.left_at = datetime.now(UTC) + participants = await self._participants(session, call.call_id) + remaining = [row for row in participants if row.member_address != payload.member_address and row.state in {"joined", "ringing", "offline_pending"}] + if not remaining: + call.state = "ended" + call.ended_at = datetime.now(UTC) + self._clear_audio_state_for_call(call.call_id) + await session.commit() + await self._notify_call_state(session, call) + await self._notify_participant_event(session, call, payload.member_address, "left") + if call.state == "ended": + for row in participants: + if row.local_user_id: + await connection_manager.send_to_user( + row.local_user_id, + {"type": "call.group.ended", "call_id": call.call_id, "reason": payload.reason}, + ) + + async def relay_end(self, session: AsyncSession, payload: FederationGroupCallEndRequestV2) -> None: + await self._require_enabled() + call = await self._call(session, payload.call_id) + call.state = "ended" + call.ended_at = datetime.now(UTC) + self._clear_audio_state_for_call(call.call_id) + await session.commit() + rows = await self._participants(session, call.call_id) + for row in rows: + if row.local_user_id: + await connection_manager.send_to_user( + row.local_user_id, + {"type": "call.group.ended", "call_id": call.call_id, "reason": payload.reason}, + ) + + async def audio(self, session: AsyncSession, user_id: str, payload: CallAudioRequest) -> None: + await self._require_enabled() + if not self.settings.enable_legacy_call_audio_ws: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="WS audio transport is disabled") + try: + pcm = base64.b64decode(payload.pcm_b64.encode("utf-8"), validate=True) + except binascii.Error as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid audio payload encoding") from exc + if not pcm: + return + if len(pcm) > self.settings.voice_audio_max_chunk_bytes: + raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Audio chunk too large") + + call = await self._call(session, payload.call_id) + if call.state != "active": + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Call is not active") + + sender_participant = ( + await session.execute( + select(GroupCallParticipant).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.local_user_id == user_id, + ) + ) + ).scalar_one_or_none() + if sender_participant is None or sender_participant.state != "joined": + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not an active call participant") + sender_address = sender_participant.member_address + + now_mono = time.monotonic() + min_interval = max(0.001, float(self.settings.voice_audio_min_interval_ms) / 1000.0) + rate_key = (call.call_id, sender_address) + previous = self._audio_last_at.get(rate_key) + if previous is not None and (now_mono - previous) < min_interval: + return + self._audio_last_at[rate_key] = now_mono + + recipients = list( + ( + await session.execute( + select(GroupCallParticipant.local_user_id).where( + GroupCallParticipant.call_id == call.call_id, + GroupCallParticipant.state == "joined", + GroupCallParticipant.member_address != sender_address, + GroupCallParticipant.local_user_id.is_not(None), + ) + ) + ).scalars() + ) + out_payload = { + "type": "call.audio", + "call_id": call.call_id, + "from_user_id": user_id, + "from_user_address": sender_address, + "sequence": payload.sequence, + "pcm_b64": payload.pcm_b64, + } + if recipients: + await asyncio.gather( + *(connection_manager.send_to_user(recipient_user_id, out_payload) for recipient_user_id in recipients), + return_exceptions=True, + ) + + async def _relay_webrtc_to_local_target( + self, + session: AsyncSession, + *, + call_id: str, + source_user_address: str, + target_user_address: str, + event_type: str, + payload: dict, + ) -> None: + await self._require_enabled() + call = await self._call(session, call_id) + participants = await self._participants(session, call.call_id) + source = source_user_address.strip().lower() + target = target_user_address.strip().lower() + source_participant = next((item for item in participants if item.member_address == source), None) + target_participant = next((item for item in participants if item.member_address == target), None) + if source_participant is None or target_participant is None: + await metrics.inc("group_calls.signal.rejected") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid signaling participants") + if not target_participant.local_user_id: + await metrics.inc("group_calls.signal.rejected") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Target user is unavailable") + outgoing = dict(payload) + outgoing["type"] = event_type + outgoing["call_id"] = call_id + outgoing["source_user_address"] = source + outgoing["target_user_address"] = target + await connection_manager.send_to_user(target_participant.local_user_id, outgoing) + + async def relay_webrtc_offer(self, session: AsyncSession, payload: FederationGroupCallWebRtcOfferRequestV2) -> None: + await self._relay_webrtc_to_local_target( + session, + call_id=payload.call_id, + source_user_address=payload.source_user_address, + target_user_address=payload.target_user_address, + event_type="call.webrtc.offer", + payload={ + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + async def relay_webrtc_answer(self, session: AsyncSession, payload: FederationGroupCallWebRtcAnswerRequestV2) -> None: + await self._relay_webrtc_to_local_target( + session, + call_id=payload.call_id, + source_user_address=payload.source_user_address, + target_user_address=payload.target_user_address, + event_type="call.webrtc.answer", + payload={ + "sdp": payload.sdp, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + async def relay_webrtc_ice(self, session: AsyncSession, payload: FederationGroupCallWebRtcIceRequestV2) -> None: + await self._relay_webrtc_to_local_target( + session, + call_id=payload.call_id, + source_user_address=payload.source_user_address, + target_user_address=payload.target_user_address, + event_type="call.webrtc.ice", + payload={ + "candidate": payload.candidate, + "sdp_mid": payload.sdp_mid, + "sdp_mline_index": payload.sdp_mline_index, + "call_schema_version": payload.call_schema_version, + "call_mode": payload.call_mode, + "max_participants": payload.max_participants, + }, + ) + + +group_call_service = GroupCallService() diff --git a/server/app/services/group_conversation_service.py b/server/app/services/group_conversation_service.py new file mode 100644 index 0000000..f42a361 --- /dev/null +++ b/server/app/services/group_conversation_service.py @@ -0,0 +1,1192 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from uuid import uuid4 + +from fastapi import HTTPException, status +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.conversation import Conversation +from app.models.conversation_member import ConversationMember +from app.models.group_membership_event import GroupMembershipEvent +from app.models.user import User +from app.schemas.v2_conversation import ConversationMemberOutV2, ConversationOutV2, ConversationRecipientsOutV2 +from app.schemas.v2_device import DeviceOutV2 +from app.schemas.v2_federation import FederationGroupEventRequestV2, FederationGroupSnapshotOutV2 +from app.services.device_service_v2 import device_service_v2 +from app.services.federation_client import FederationClientError, federation_client +from app.services.federation_outbox_service import federation_outbox_service +from app.services.metrics import metrics +from app.services.peer_address import parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority +from app.services.server_identity import get_server_onion, server_address_for_username +from app.ws.manager import connection_manager + + +@dataclass(slots=True) +class GroupEventRecord: + group_uid: str + event_seq: int + event_type: str + actor_address: str + target_address: str + payload: dict + created_at: datetime + + +class GroupConversationService: + def __init__(self) -> None: + self.settings = get_settings() + + async def _notify_group_renamed_local( + self, + session: AsyncSession, + *, + conversation: Conversation, + actor_address: str, + event_seq: int = 0, + ) -> None: + members = await self._members_for_conversation(session, conversation.id) + local_user_ids: set[str] = set() + for row in members: + if row.member_user_id is None: + continue + if row.status not in {"active", "invited"}: + continue + local_user_ids.add(row.member_user_id) + if not local_user_ids: + return + + payload = { + "type": "conversation.group.renamed", + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "group_name": conversation.group_name, + "actor_address": actor_address, + "event_seq": int(event_seq or 0), + } + for user_id in local_user_ids: + await connection_manager.send_to_user(user_id, payload) + + @staticmethod + def _member_identity_key(member_user_id: str | None, member_address: str) -> str: + if member_user_id: + return f"user:{member_user_id}" + return f"addr:{member_address}" + + async def _ensure_owner_member_row(self, session: AsyncSession, conversation: Conversation) -> None: + if conversation.conversation_type != "group": + return + owner_address = (conversation.owner_address or "").strip().lower() + if not owner_address: + return + stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == owner_address, + ) + existing = (await session.execute(stmt)).scalar_one_or_none() + if existing is not None: + if existing.member_user_id is None or existing.role != "owner" or existing.status != "active": + try: + member_user_id, member_server_onion = await self._resolve_member_user_id( + session, + owner_address, + ) + existing.member_user_id = member_user_id + if member_server_onion: + existing.member_server_onion = member_server_onion + except HTTPException: + pass + existing.role = "owner" + existing.status = "active" + if existing.joined_at is None: + existing.joined_at = datetime.now(UTC) + existing.left_at = None + existing.updated_at = datetime.now(UTC) + await session.flush() + return + + now = datetime.now(UTC) + member_user_id: str | None = None + member_server_onion = "" + try: + member_user_id, member_server_onion = await self._resolve_member_user_id( + session, + owner_address, + ) + except HTTPException: + try: + parsed = parse_peer_address_with_policy(owner_address, self.settings.tor_enabled) + member_server_onion = parsed.server_onion + except ValueError: + member_server_onion = "" + + session.add( + ConversationMember( + conversation_id=conversation.id, + member_user_id=member_user_id, + member_address=owner_address, + member_server_onion=member_server_onion, + role="owner", + status="active", + invited_by_address=owner_address, + invited_at=now, + joined_at=now, + updated_at=now, + ) + ) + await session.flush() + + async def _next_event_seq(self, session: AsyncSession, group_uid: str) -> int: + stmt = select(func.max(GroupMembershipEvent.event_seq)).where(GroupMembershipEvent.group_uid == group_uid) + current = (await session.execute(stmt)).scalar_one_or_none() + return 1 if current is None else int(current) + 1 + + async def _current_event_seq(self, session: AsyncSession, group_uid: str) -> int: + stmt = select(func.max(GroupMembershipEvent.event_seq)).where(GroupMembershipEvent.group_uid == group_uid) + current = (await session.execute(stmt)).scalar_one_or_none() + return int(current or 0) + + async def _event_exists(self, session: AsyncSession, group_uid: str, event_seq: int) -> bool: + stmt = ( + select(GroupMembershipEvent.id) + .where( + GroupMembershipEvent.group_uid == group_uid, + GroupMembershipEvent.event_seq == event_seq, + ) + .limit(1) + ) + return (await session.execute(stmt)).scalar_one_or_none() is not None + + async def _append_event( + self, + session: AsyncSession, + *, + group_uid: str, + event_seq: int, + event_type: str, + actor_address: str, + target_address: str, + payload: dict, + created_at: datetime | None = None, + ) -> GroupMembershipEvent: + row = GroupMembershipEvent( + group_uid=group_uid, + event_seq=event_seq, + event_type=event_type, + actor_address=actor_address, + target_address=target_address, + payload_json=payload, + created_at=created_at or datetime.now(UTC), + ) + session.add(row) + await session.flush() + return row + + async def _serialize_event( + self, + session: AsyncSession, + *, + conversation: Conversation, + event_type: str, + actor_address: str, + target_address: str, + payload: dict, + created_at: datetime | None = None, + ) -> GroupEventRecord: + if not conversation.group_uid: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Group conversation missing group_uid") + seq = await self._next_event_seq(session, conversation.group_uid) + row = await self._append_event( + session, + group_uid=conversation.group_uid, + event_seq=seq, + event_type=event_type, + actor_address=actor_address, + target_address=target_address, + payload=payload, + created_at=created_at, + ) + return GroupEventRecord( + group_uid=row.group_uid, + event_seq=row.event_seq, + event_type=row.event_type, + actor_address=row.actor_address, + target_address=row.target_address, + payload=row.payload_json, + created_at=row.created_at, + ) + + async def _members_for_conversation(self, session: AsyncSession, conversation_id: str) -> list[ConversationMember]: + stmt = ( + select(ConversationMember) + .where(ConversationMember.conversation_id == conversation_id) + .order_by(ConversationMember.invited_at.asc()) + ) + rows = list((await session.execute(stmt)).scalars().all()) + by_address: dict[str, ConversationMember] = {} + for row in rows: + existing = by_address.get(row.member_address) + if existing is None or row.updated_at >= existing.updated_at: + by_address[row.member_address] = row + return list(by_address.values()) + + async def _member_for_user( + self, + session: AsyncSession, + conversation_id: str, + user_id: str, + ) -> ConversationMember | None: + stmt = ( + select(ConversationMember) + .where( + ConversationMember.conversation_id == conversation_id, + ConversationMember.member_user_id == user_id, + ) + .order_by(ConversationMember.role.desc(), ConversationMember.status.asc(), ConversationMember.updated_at.desc()) + ) + return (await session.execute(stmt)).scalars().first() + + async def ensure_member_access( + self, + session: AsyncSession, + conversation: Conversation, + user_id: str, + *, + require_active: bool = True, + ) -> ConversationMember: + await self._ensure_owner_member_row(session, conversation) + member = await self._member_for_user(session, conversation.id, user_id) + if member is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a conversation member") + if require_active and member.status != "active": + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not an active conversation member") + return member + + async def ensure_owner(self, session: AsyncSession, conversation: Conversation, user_id: str) -> ConversationMember: + member = await self.ensure_member_access(session, conversation, user_id, require_active=True) + if member.role != "owner": + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only owner can perform this action") + return member + + def _normalize_name(self, value: str) -> str: + normalized = value.strip() + if not normalized: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Group name is required") + return normalized + + def _normalize_member_address(self, value: str) -> str: + try: + parsed = parse_peer_address_with_policy(value, self.settings.tor_enabled) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + return parsed.canonical + + async def _resolve_member_user_id( + self, + session: AsyncSession, + member_address: str, + request_authority: str | None = None, + ) -> tuple[str | None, str]: + parsed = parse_peer_address_with_policy(member_address, self.settings.tor_enabled) + additional_aliases = {request_authority} if request_authority else None + if is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): + user_stmt = select(User).where(User.username == parsed.username, User.disabled_at.is_(None)) + user = (await session.execute(user_stmt)).scalar_one_or_none() + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Local user not found for {parsed.canonical}", + ) + return user.id, parsed.server_onion + return None, parsed.server_onion + + async def _broadcast_events( + self, + session: AsyncSession, + conversation: Conversation, + events: list[GroupEventRecord], + ) -> None: + if conversation.origin_server_onion != get_server_onion(): + return + members = await self._members_for_conversation(session, conversation.id) + remote_servers = { + member.member_server_onion + for member in members + if member.member_server_onion + and not is_local_server_authority(member.member_server_onion, self.settings) + and member.status in {"invited", "active"} + } + for peer_onion in remote_servers: + for event in events: + payload = { + "relay_id": str(uuid4()), + "group_uid": event.group_uid, + "event_seq": event.event_seq, + "event_type": event.event_type, + "actor_address": event.actor_address, + "target_address": event.target_address, + "payload": event.payload, + "created_at": event.created_at.isoformat(), + } + outbox_item = await federation_outbox_service.enqueue( + session, + peer_onion=peer_onion, + event_type="group.event", + endpoint_path="/api/v2/federation/groups/events", + payload_json=payload, + dedupe_key=f"group-event:{event.group_uid}:{event.event_seq}:{peer_onion}", + ) + await federation_outbox_service.deliver_item(session, outbox_item.id) + + async def create_group( + self, + session: AsyncSession, + owner: User, + *, + name: str, + member_addresses: list[str], + request_authority: str | None = None, + ) -> Conversation: + if not self.settings.enable_group_dm_v2c: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group conversations are disabled") + owner_address = server_address_for_username(owner.username) + normalized_name = self._normalize_name(name) + seen: set[str] = {owner_address} + seen_identity_keys: set[str] = {self._member_identity_key(owner.id, owner_address)} + normalized_members: list[str] = [] + normalized_member_identity_keys: list[str] = [] + for raw in member_addresses: + canonical = self._normalize_member_address(raw) + if canonical in seen: + continue + member_user_id, _ = await self._resolve_member_user_id( + session, + canonical, + request_authority=request_authority, + ) + if member_user_id == owner.id: + continue + member_identity_key = self._member_identity_key(member_user_id, canonical) + if member_identity_key in seen_identity_keys: + continue + seen.add(canonical) + seen_identity_keys.add(member_identity_key) + normalized_members.append(canonical) + normalized_member_identity_keys.append(member_identity_key) + if 1 + len(normalized_member_identity_keys) > self.settings.group_max_members: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Group member limit exceeded " + f"(requested_total={1 + len(normalized_member_identity_keys)}, " + f"max={self.settings.group_max_members})" + ), + ) + + now = datetime.now(UTC) + conversation = Conversation( + kind="group", + conversation_type="group", + group_uid=uuid4().hex, + group_name=normalized_name, + origin_server_onion=get_server_onion(), + owner_address=owner_address, + created_at=now, + ) + session.add(conversation) + await session.flush() + session.add( + ConversationMember( + conversation_id=conversation.id, + member_user_id=owner.id, + member_address=owner_address, + member_server_onion=get_server_onion(), + role="owner", + status="active", + invited_by_address=owner_address, + invited_at=now, + joined_at=now, + updated_at=now, + ) + ) + invite_rows: list[ConversationMember] = [] + for address in normalized_members: + member_user_id, member_server = await self._resolve_member_user_id( + session, + address, + request_authority=request_authority, + ) + if member_user_id == owner.id: + continue + row = ConversationMember( + conversation_id=conversation.id, + member_user_id=member_user_id, + member_address=address, + member_server_onion=member_server, + role="member", + status="invited", + invited_by_address=owner_address, + invited_at=now, + updated_at=now, + ) + session.add(row) + invite_rows.append(row) + + events = [ + await self._serialize_event( + session, + conversation=conversation, + event_type="create", + actor_address=owner_address, + target_address="", + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "group_name": conversation.group_name, + "origin_server_onion": conversation.origin_server_onion, + "owner_address": conversation.owner_address, + }, + created_at=now, + ) + ] + for row in invite_rows: + events.append( + await self._serialize_event( + session, + conversation=conversation, + event_type="invite", + actor_address=owner_address, + target_address=row.member_address, + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "member_address": row.member_address, + "member_server_onion": row.member_server_onion, + "status": "invited", + "role": row.role, + }, + created_at=now, + ) + ) + await self._broadcast_events(session, conversation, events) + await session.commit() + await metrics.inc("groups.created") + return conversation + + async def list_members( + self, + session: AsyncSession, + conversation: Conversation, + user_id: str, + ) -> list[ConversationMember]: + await self.ensure_member_access(session, conversation, user_id, require_active=False) + return await self._members_for_conversation(session, conversation.id) + + async def invite_members( + self, + session: AsyncSession, + *, + conversation: Conversation, + owner: User, + member_addresses: list[str], + request_authority: str | None = None, + ) -> list[ConversationMember]: + await self.ensure_owner(session, conversation, owner.id) + if conversation.origin_server_onion != get_server_onion(): + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Only origin server can modify membership") + current = await self._members_for_conversation(session, conversation.id) + by_address = {row.member_address: row for row in current} + active_identity_keys: set[str] = set() + for row in current: + if row.status not in {"active", "invited"}: + continue + active_identity_keys.add(self._member_identity_key(row.member_user_id, row.member_address)) + active_count = len(active_identity_keys) + owner_address = server_address_for_username(owner.username) + now = datetime.now(UTC) + rows: list[ConversationMember] = [] + events: list[GroupEventRecord] = [] + for raw in member_addresses: + canonical = self._normalize_member_address(raw) + if canonical == owner_address: + continue + existing = by_address.get(canonical) + if existing is not None and existing.member_user_id == owner.id: + continue + if existing is not None and existing.status in {"active", "invited"}: + continue + if existing is None: + member_user_id, member_server = await self._resolve_member_user_id( + session, + canonical, + request_authority=request_authority, + ) + if member_user_id == owner.id: + continue + member_identity_key = self._member_identity_key(member_user_id, canonical) + if member_identity_key in active_identity_keys: + continue + if active_count + 1 > self.settings.group_max_members: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Group member limit exceeded " + f"(active_or_invited={active_count}, max={self.settings.group_max_members})" + ), + ) + active_identity_keys.add(member_identity_key) + active_count += 1 + existing = ConversationMember( + conversation_id=conversation.id, + member_user_id=member_user_id, + member_address=canonical, + member_server_onion=member_server, + role="member", + status="invited", + invited_by_address=owner_address, + invited_at=now, + updated_at=now, + ) + session.add(existing) + else: + member_identity_key = self._member_identity_key(existing.member_user_id, existing.member_address) + if member_identity_key not in active_identity_keys: + if active_count + 1 > self.settings.group_max_members: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Group member limit exceeded " + f"(active_or_invited={active_count}, max={self.settings.group_max_members})" + ), + ) + active_identity_keys.add(member_identity_key) + active_count += 1 + existing.status = "invited" + existing.role = "member" + existing.invited_by_address = owner_address + existing.invited_at = now + existing.joined_at = None + existing.left_at = None + existing.updated_at = now + rows.append(existing) + events.append( + await self._serialize_event( + session, + conversation=conversation, + event_type="invite", + actor_address=owner_address, + target_address=canonical, + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "member_address": canonical, + "member_server_onion": existing.member_server_onion, + "status": "invited", + "role": "member", + }, + created_at=now, + ) + ) + if events: + await self._broadcast_events(session, conversation, events) + await session.commit() + await metrics.inc("groups.invited", len(events)) + return rows + + async def remove_member( + self, + session: AsyncSession, + *, + conversation: Conversation, + owner: User, + member_address: str, + ) -> ConversationMember: + await self.ensure_owner(session, conversation, owner.id) + target = self._normalize_member_address(member_address) + owner_address = server_address_for_username(owner.username) + if target == owner_address: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Owner cannot remove self") + stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == target, + ) + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found") + now = datetime.now(UTC) + row.status = "removed" + row.left_at = now + row.updated_at = now + event = await self._serialize_event( + session, + conversation=conversation, + event_type="remove", + actor_address=owner_address, + target_address=target, + payload={"conversation_id": conversation.id, "group_uid": conversation.group_uid, "member_address": target}, + created_at=now, + ) + await self._broadcast_events(session, conversation, [event]) + await session.commit() + await metrics.inc("groups.removed") + return row + + async def leave( + self, + session: AsyncSession, + *, + conversation: Conversation, + user: User, + reason: str | None = None, + ) -> ConversationMember: + member = await self.ensure_member_access(session, conversation, user.id, require_active=False) + now = datetime.now(UTC) + actor_address = (member.member_address or server_address_for_username(user.username)).strip().lower() + leave_reason = (reason or "").strip() or "left" + + if member.role == "owner" and member.status in {"active", "invited"}: + eligible_stmt = ( + select(ConversationMember) + .where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.role != "owner", + ConversationMember.status.in_(("active", "invited")), + ) + .order_by(ConversationMember.invited_at.asc(), ConversationMember.member_address.asc()) + .limit(1) + ) + replacement = (await session.execute(eligible_stmt)).scalar_one_or_none() + + if replacement is not None: + member.role = "member" + member.status = "left" + member.left_at = now + member.updated_at = now + + replacement.role = "owner" + replacement.status = "active" + if replacement.joined_at is None: + replacement.joined_at = now + replacement.left_at = None + replacement.updated_at = now + conversation.owner_address = replacement.member_address + + if conversation.origin_server_onion == get_server_onion(): + transfer_event = await self._serialize_event( + session, + conversation=conversation, + event_type="transfer_owner", + actor_address=actor_address, + target_address=replacement.member_address, + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "previous_owner_address": actor_address, + "owner_address": replacement.member_address, + "member_address": replacement.member_address, + }, + created_at=now, + ) + leave_event = await self._serialize_event( + session, + conversation=conversation, + event_type="leave", + actor_address=actor_address, + target_address=actor_address, + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "member_address": actor_address, + "reason": leave_reason, + }, + created_at=now, + ) + await self._broadcast_events(session, conversation, [transfer_event, leave_event]) + + await session.commit() + await metrics.inc("groups.owner_transferred") + await metrics.inc("groups.left") + return member + + synthetic = ConversationMember( + id=member.id, + conversation_id=member.conversation_id, + member_user_id=member.member_user_id, + member_address=member.member_address, + member_server_onion=member.member_server_onion, + role="member", + status="left", + invited_by_address=member.invited_by_address, + invited_at=member.invited_at, + joined_at=member.joined_at, + left_at=now, + updated_at=now, + ) + await session.delete(conversation) + await session.commit() + await metrics.inc("groups.deleted_on_owner_leave") + await metrics.inc("groups.left") + return synthetic + + member.status = "left" + member.left_at = now + member.updated_at = now + if conversation.origin_server_onion == get_server_onion(): + event = await self._serialize_event( + session, + conversation=conversation, + event_type="leave", + actor_address=actor_address, + target_address=actor_address, + payload={ + "conversation_id": conversation.id, + "group_uid": conversation.group_uid, + "member_address": actor_address, + "reason": leave_reason, + }, + created_at=now, + ) + await self._broadcast_events(session, conversation, [event]) + await session.commit() + await metrics.inc("groups.left") + return member + + async def rename(self, session: AsyncSession, *, conversation: Conversation, owner: User, name: str) -> Conversation: + await self.ensure_owner(session, conversation, owner.id) + conversation.group_name = self._normalize_name(name) + now = datetime.now(UTC) + actor_address = server_address_for_username(owner.username) + event = await self._serialize_event( + session, + conversation=conversation, + event_type="rename", + actor_address=actor_address, + target_address="", + payload={"conversation_id": conversation.id, "group_uid": conversation.group_uid, "group_name": conversation.group_name}, + created_at=now, + ) + await self._broadcast_events(session, conversation, [event]) + await session.commit() + await self._notify_group_renamed_local( + session, + conversation=conversation, + actor_address=actor_address, + event_seq=event.event_seq, + ) + await metrics.inc("groups.rename") + return conversation + + async def accept_invite(self, session: AsyncSession, *, conversation: Conversation, user: User) -> ConversationMember: + member = await self.ensure_member_access(session, conversation, user.id, require_active=False) + actor_address = server_address_for_username(user.username) + now = datetime.now(UTC) + member.status = "active" + member.joined_at = now + member.left_at = None + member.updated_at = now + if conversation.origin_server_onion == get_server_onion(): + event = await self._serialize_event( + session, + conversation=conversation, + event_type="accept", + actor_address=actor_address, + target_address=actor_address, + payload={"conversation_id": conversation.id, "group_uid": conversation.group_uid, "member_address": actor_address}, + created_at=now, + ) + await self._broadcast_events(session, conversation, [event]) + else: + try: + await federation_client.post_signed( + conversation.origin_server_onion, + "/api/v2/federation/groups/invites/accept", + { + "relay_id": str(uuid4()), + "group_uid": conversation.group_uid, + "conversation_id": conversation.id, + "actor_address": actor_address, + }, + ) + except FederationClientError as exc: + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=exc.detail) from exc + await session.commit() + await metrics.inc("groups.accepted") + return member + + async def members_for_recipients( + self, + session: AsyncSession, + *, + conversation: Conversation, + user: User, + current_device_uid: str | None = None, + ) -> ConversationRecipientsOutV2: + await self.ensure_member_access(session, conversation, user.id, require_active=True) + members = await self._members_for_conversation(session, conversation.id) + recipients = [] + for member in members: + if member.status != "active": + continue + lookup = await device_service_v2.resolve_devices_by_peer_address(session, member.member_address) + if lookup is None: + continue + for device in lookup.devices: + if member.member_user_id == user.id and current_device_uid and device.device_uid == current_device_uid: + continue + recipients.append( + { + "member_address": member.member_address, + "member_status": "active", + "device": DeviceOutV2.model_validate(device), + "prekey": None, + } + ) + return ConversationRecipientsOutV2(conversation_id=conversation.id, conversation_type="group", recipients=recipients) + + async def to_conversation_out(self, session: AsyncSession, *, conversation: Conversation, user: User) -> ConversationOutV2: + members = await self._members_for_conversation(session, conversation.id) + member = next((row for row in members if row.member_user_id == user.id), None) + member_count_keys: set[str] = set() + for row in members: + if row.status not in {"active", "invited"}: + continue + member_count_keys.add(self._member_identity_key(row.member_user_id, row.member_address)) + return ConversationOutV2( + id=conversation.id, + kind=conversation.kind, + user_a_id=conversation.user_a_id or "", + user_b_id=conversation.user_b_id or "", + local_user_id=conversation.local_user_id or "", + created_at=conversation.created_at, + peer_username=conversation.peer_username, + peer_server_onion=conversation.peer_server_onion, + peer_address=conversation.peer_address, + conversation_type="group", + group_uid=conversation.group_uid, + group_name=conversation.group_name, + member_count=len(member_count_keys), + membership_state=("none" if member is None else member.status), # type: ignore[arg-type] + can_manage_members=bool(member is not None and member.role == "owner" and member.status == "active"), + origin_server_onion=conversation.origin_server_onion, + owner_address=conversation.owner_address, + ) + + async def get_group_by_uid(self, session: AsyncSession, group_uid: str) -> Conversation | None: + stmt = select(Conversation).where(Conversation.group_uid == group_uid, Conversation.conversation_type == "group") + return (await session.execute(stmt)).scalar_one_or_none() + + async def _apply_snapshot(self, session: AsyncSession, snapshot: FederationGroupSnapshotOutV2) -> Conversation: + conversation = await self.get_group_by_uid(session, snapshot.group_uid) + if conversation is None: + conversation = Conversation( + kind="group_remote", + conversation_type="group", + group_uid=snapshot.group_uid, + group_name=snapshot.group_name, + origin_server_onion=snapshot.origin_server_onion, + owner_address=snapshot.owner_address, + ) + session.add(conversation) + await session.flush() + else: + conversation.group_name = snapshot.group_name + conversation.origin_server_onion = snapshot.origin_server_onion + conversation.owner_address = snapshot.owner_address + + await session.execute(delete(ConversationMember).where(ConversationMember.conversation_id == conversation.id)) + for member_payload in snapshot.members: + member_address = str(member_payload.get("member_address", "")).strip().lower() + if not member_address: + continue + parsed = parse_peer_address_with_policy(member_address, self.settings.tor_enabled) + member_user_id = member_payload.get("member_user_id") + if is_local_server_authority(parsed.server_onion, self.settings): + user_stmt = select(User).where(User.username == parsed.username, User.disabled_at.is_(None)) + local_user = (await session.execute(user_stmt)).scalar_one_or_none() + member_user_id = local_user.id if local_user is not None else None + invited_at_raw = str(member_payload.get("invited_at", datetime.now(UTC).isoformat())) + invited_at = datetime.fromisoformat(invited_at_raw.replace("Z", "+00:00")) + joined_at_raw = member_payload.get("joined_at") + left_at_raw = member_payload.get("left_at") + updated_at_raw = str(member_payload.get("updated_at", invited_at_raw)) + session.add( + ConversationMember( + conversation_id=conversation.id, + member_user_id=member_user_id, + member_address=parsed.canonical, + member_server_onion=parsed.server_onion, + role=str(member_payload.get("role", "member")), + status=str(member_payload.get("status", "invited")), + invited_by_address=str(member_payload.get("invited_by_address", "")), + invited_at=invited_at, + joined_at=None if not joined_at_raw else datetime.fromisoformat(str(joined_at_raw).replace("Z", "+00:00")), + left_at=None if not left_at_raw else datetime.fromisoformat(str(left_at_raw).replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(updated_at_raw.replace("Z", "+00:00")), + ) + ) + await session.execute(delete(GroupMembershipEvent).where(GroupMembershipEvent.group_uid == snapshot.group_uid)) + if snapshot.latest_event_seq > 0: + session.add( + GroupMembershipEvent( + group_uid=snapshot.group_uid, + event_seq=snapshot.latest_event_seq, + event_type="snapshot", + actor_address=snapshot.owner_address, + target_address="", + payload_json={"source": "federation_snapshot"}, + created_at=datetime.now(UTC), + ) + ) + await session.flush() + return conversation + + async def snapshot_for_group(self, session: AsyncSession, group_uid: str) -> FederationGroupSnapshotOutV2: + conversation = await self.get_group_by_uid(session, group_uid) + if conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found") + members = await self._members_for_conversation(session, conversation.id) + latest_stmt = select(func.max(GroupMembershipEvent.event_seq)).where(GroupMembershipEvent.group_uid == group_uid) + latest = int((await session.execute(latest_stmt)).scalar_one_or_none() or 0) + payload_members = [ + { + "member_user_id": row.member_user_id, + "member_address": row.member_address, + "member_server_onion": row.member_server_onion, + "role": row.role, + "status": row.status, + "invited_by_address": row.invited_by_address, + "invited_at": row.invited_at.isoformat(), + "joined_at": row.joined_at.isoformat() if row.joined_at else None, + "left_at": row.left_at.isoformat() if row.left_at else None, + "updated_at": row.updated_at.isoformat(), + } + for row in members + ] + return FederationGroupSnapshotOutV2( + group_uid=group_uid, + conversation_id=conversation.id, + group_name=conversation.group_name, + origin_server_onion=conversation.origin_server_onion, + owner_address=conversation.owner_address, + latest_event_seq=latest, + members=payload_members, + ) + + async def apply_federation_event( + self, + session: AsyncSession, + *, + payload: FederationGroupEventRequestV2, + sender_onion: str | None = None, + ) -> None: + current_seq = await self._current_event_seq(session, payload.group_uid) + if payload.event_seq <= current_seq: + if await self._event_exists(session, payload.group_uid, payload.event_seq): + return + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Group event sequence conflict") + expected_seq = current_seq + 1 + if payload.event_seq > expected_seq: + if not sender_onion: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Group event sequence gap") + try: + snapshot = await federation_client.get_remote_group_snapshot_v2(sender_onion, payload.group_uid) + except FederationClientError as exc: + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=exc.detail) from exc + await self._apply_snapshot(session, snapshot) + current_seq = await self._current_event_seq(session, payload.group_uid) + expected_seq = current_seq + 1 + if payload.event_seq <= current_seq: + if await self._event_exists(session, payload.group_uid, payload.event_seq): + return + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Group event sequence conflict") + if payload.event_seq > expected_seq: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Group event sequence gap") + + conversation = await self.get_group_by_uid(session, payload.group_uid) + if conversation is None and payload.event_type != "create": + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found") + if conversation is None: + conversation = Conversation( + kind="group_remote", + conversation_type="group", + group_uid=payload.group_uid, + group_name=str(payload.payload.get("group_name", "")), + origin_server_onion=str(payload.payload.get("origin_server_onion", "")), + owner_address=str(payload.payload.get("owner_address", payload.actor_address)), + ) + session.add(conversation) + await session.flush() + created_at = datetime.fromisoformat(payload.created_at.replace("Z", "+00:00")) + if payload.event_type == "create": + conversation.group_name = str(payload.payload.get("group_name", conversation.group_name)) + conversation.origin_server_onion = str(payload.payload.get("origin_server_onion", conversation.origin_server_onion)) + conversation.owner_address = str(payload.payload.get("owner_address", conversation.owner_address or payload.actor_address)) + if payload.event_type == "rename": + conversation.group_name = str(payload.payload.get("group_name", conversation.group_name)) + elif payload.event_type == "transfer_owner": + prior_owner_raw = ( + str(payload.payload.get("previous_owner_address", payload.actor_address or "")).strip().lower() + ) + next_owner_raw = str( + payload.payload.get("owner_address") + or payload.target_address + or payload.payload.get("member_address") + or "" + ).strip().lower() + if next_owner_raw: + parsed_next_owner = parse_peer_address_with_policy(next_owner_raw, self.settings.tor_enabled) + conversation.owner_address = parsed_next_owner.canonical + + prior_owner_canonical = "" + if prior_owner_raw: + try: + prior_owner_canonical = parse_peer_address_with_policy( + prior_owner_raw, self.settings.tor_enabled + ).canonical + except ValueError: + prior_owner_canonical = "" + + if prior_owner_canonical: + prior_owner_stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == prior_owner_canonical, + ) + prior_owner_row = (await session.execute(prior_owner_stmt)).scalar_one_or_none() + if prior_owner_row is not None: + prior_owner_row.role = "member" + prior_owner_row.updated_at = created_at + + next_owner_stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == parsed_next_owner.canonical, + ) + next_owner_row = (await session.execute(next_owner_stmt)).scalar_one_or_none() + if next_owner_row is None: + next_owner_row = ConversationMember( + conversation_id=conversation.id, + member_user_id=None, + member_address=parsed_next_owner.canonical, + member_server_onion=parsed_next_owner.server_onion, + role="owner", + status="active", + invited_by_address=payload.actor_address, + invited_at=created_at, + joined_at=created_at, + updated_at=created_at, + ) + session.add(next_owner_row) + else: + next_owner_row.role = "owner" + next_owner_row.status = "active" + if next_owner_row.joined_at is None: + next_owner_row.joined_at = created_at + next_owner_row.left_at = None + next_owner_row.updated_at = created_at + elif payload.event_type in {"invite", "accept", "remove", "leave", "create"}: + target = payload.target_address or payload.actor_address or conversation.owner_address + parsed = parse_peer_address_with_policy(target, self.settings.tor_enabled) + member_stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == parsed.canonical, + ) + row = (await session.execute(member_stmt)).scalar_one_or_none() + if row is None: + row = ConversationMember( + conversation_id=conversation.id, + member_user_id=None, + member_address=parsed.canonical, + member_server_onion=parsed.server_onion, + role="owner" if payload.event_type == "create" else "member", + status="active" if payload.event_type in {"create", "accept"} else "invited", + invited_by_address=payload.actor_address, + invited_at=created_at, + joined_at=created_at if payload.event_type in {"create", "accept"} else None, + updated_at=created_at, + ) + session.add(row) + else: + row.status = ( + "active" + if payload.event_type in {"accept", "create"} + else "removed" + if payload.event_type == "remove" + else "left" + if payload.event_type == "leave" + else "invited" + ) + if payload.event_type == "create": + row.role = "owner" + row.updated_at = created_at + if row.status == "active": + row.joined_at = created_at + row.left_at = None + if row.status in {"removed", "left"}: + row.left_at = created_at + await self._append_event( + session, + group_uid=payload.group_uid, + event_seq=payload.event_seq, + event_type=payload.event_type, + actor_address=payload.actor_address, + target_address=payload.target_address, + payload=payload.payload, + created_at=created_at, + ) + await session.commit() + if payload.event_type == "rename": + await self._notify_group_renamed_local( + session, + conversation=conversation, + actor_address=payload.actor_address, + event_seq=payload.event_seq, + ) + + async def accept_remote_invite_to_origin(self, session: AsyncSession, *, group_uid: str, actor_address: str) -> None: + conversation = await self.get_group_by_uid(session, group_uid) + if conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found") + stmt = select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.member_address == actor_address, + ) + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Member not found") + now = datetime.now(UTC) + row.status = "active" + row.joined_at = now + row.left_at = None + row.updated_at = now + event = await self._serialize_event( + session, + conversation=conversation, + event_type="accept", + actor_address=actor_address, + target_address=actor_address, + payload={"conversation_id": conversation.id, "group_uid": conversation.group_uid, "member_address": actor_address}, + created_at=now, + ) + await self._broadcast_events(session, conversation, [event]) + await session.commit() + + def member_out(self, row: ConversationMember) -> ConversationMemberOutV2: + return ConversationMemberOutV2( + id=row.id, + member_user_id=row.member_user_id, + member_address=row.member_address, + member_server_onion=row.member_server_onion, + role=row.role, # type: ignore[arg-type] + status=row.status, # type: ignore[arg-type] + invited_by_address=row.invited_by_address, + invited_at=row.invited_at, + joined_at=row.joined_at, + left_at=row.left_at, + updated_at=row.updated_at, + ) + + +group_conversation_service = GroupConversationService() diff --git a/server/app/services/message_service_v2.py b/server/app/services/message_service_v2.py new file mode 100644 index 0000000..81fc72e --- /dev/null +++ b/server/app/services/message_service_v2.py @@ -0,0 +1,933 @@ +import base64 +import hashlib +import json +from datetime import UTC, datetime, timedelta +from typing import Any +from uuid import uuid4 + +from fastapi import HTTPException, status +from nacl import encoding, signing +from nacl import exceptions as nacl_exceptions +from sqlalchemy import and_, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.conversation import Conversation +from app.models.conversation_member import ConversationMember +from app.models.device import Device +from app.models.message_device_copy import MessageDeviceCopy +from app.models.message_event import MessageEvent +from app.models.user import User +from app.schemas.v2_federation import FederationMessageRelayRequestV2 +from app.schemas.v2_message import MessageSendRequestV2, SignedEnvelopeV2 +from app.services.conversation_service import conversation_service +from app.services.device_service_v2 import device_service_v2 +from app.services.federation_outbox_service import federation_outbox_service +from app.services.group_conversation_service import group_conversation_service +from app.services.metrics import metrics +from app.services.peer_address import parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority +from app.services.server_identity import get_server_onion, server_address_for_username +from app.ws.manager import connection_manager + +SEALED_MODE = "sealedbox_v0_2a" +RATCHET_MODE = "ratchet_v0_2b1" + + +def _sha256_hex(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _decode_b64(value: str | None) -> bytes: + if value is None: + return b"" + return base64.b64decode(value.encode("utf-8"), validate=True) + + +def canonical_message_signature_string( + sender_address: str, + sender_device_uid: str, + recipient_user_address: str, + recipient_device_uid: str, + client_message_id: str, + sent_at_ms: int, + sender_prev_hash: str, + sender_chain_hash: str, + ciphertext_hash: str, + aad_hash: str, +) -> bytes: + canonical = "\n".join( + [ + sender_address, + sender_device_uid, + recipient_user_address, + recipient_device_uid, + client_message_id, + str(sent_at_ms), + sender_prev_hash, + sender_chain_hash, + ciphertext_hash, + aad_hash, + ] + ) + return canonical.encode("utf-8") + + +class MessageServiceV2: + def __init__(self) -> None: + self.settings = get_settings() + + def _validate_envelope_limits(self, envelope: SignedEnvelopeV2) -> None: + ciphertext_bytes = _decode_b64(envelope.ciphertext_b64) + if len(ciphertext_bytes) > self.settings.max_ciphertext_bytes: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail="Ciphertext is too large", + ) + + aad_bytes = _decode_b64(envelope.aad_b64) + if len(aad_bytes) > self.settings.max_aad_bytes: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail="AAD is too large", + ) + + @staticmethod + def _envelope_storage_bytes(envelope_json: dict[str, Any]) -> int: + return len(json.dumps(envelope_json, separators=(",", ":"), sort_keys=True).encode("utf-8")) + + async def _pending_queue_usage(self, session: AsyncSession, recipient_device_uid: str) -> tuple[int, int]: + stmt = select(MessageDeviceCopy).where( + MessageDeviceCopy.recipient_device_uid == recipient_device_uid, + MessageDeviceCopy.status == "pending", + ) + rows = list((await session.execute(stmt)).scalars().all()) + pending_bytes = 0 + for row in rows: + pending_bytes += self._envelope_storage_bytes(row.envelope_json) + return len(rows), pending_bytes + + async def _enforce_pending_queue_limits( + self, + session: AsyncSession, + copy_candidates: list[tuple[str, dict[str, Any]]], + ) -> None: + usage_by_device: dict[str, tuple[int, int]] = {} + projected_by_device: dict[str, tuple[int, int]] = {} + for recipient_device_uid, envelope_json in copy_candidates: + projected_count, projected_bytes = projected_by_device.get(recipient_device_uid, (0, 0)) + projected_by_device[recipient_device_uid] = ( + projected_count + 1, + projected_bytes + self._envelope_storage_bytes(envelope_json), + ) + + for recipient_device_uid, (new_count, new_bytes) in projected_by_device.items(): + if recipient_device_uid not in usage_by_device: + usage_by_device[recipient_device_uid] = await self._pending_queue_usage(session, recipient_device_uid) + existing_count, existing_bytes = usage_by_device[recipient_device_uid] + if existing_count + new_count > self.settings.pending_queue_max_copies_per_device: + await metrics.inc("attachments.send.rejected_queue_pressure") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Recipient device pending queue copy limit exceeded", + ) + if existing_bytes + new_bytes > self.settings.pending_queue_max_bytes_per_device: + await metrics.inc("attachments.send.rejected_queue_pressure") + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail="Recipient device pending queue byte limit exceeded", + ) + + @staticmethod + def _validate_encryption_mode(mode: str) -> str: + normalized = (mode or SEALED_MODE).strip().lower() + if normalized not in {SEALED_MODE, RATCHET_MODE}: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported encryption_mode") + return normalized + + def _enforce_encryption_policy(self, conversation_kind: str, encryption_mode: str) -> None: + if encryption_mode == RATCHET_MODE and not self.settings.enable_ratchet_v2b1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="ratchet_v0_2b1 is disabled on this server", + ) + + if conversation_kind == "local" and self.settings.ratchet_require_for_local and encryption_mode != RATCHET_MODE: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Ratchet mode is required for local delivery", + ) + if conversation_kind == "remote" and self.settings.ratchet_require_for_federation and encryption_mode != RATCHET_MODE: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Ratchet mode is required for federation delivery", + ) + + @staticmethod + def _validate_mode_envelope_shape(encryption_mode: str, envelope: SignedEnvelopeV2) -> None: + if encryption_mode == RATCHET_MODE: + if envelope.ratchet_header is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="ratchet_header is required for ratchet_v0_2b1", + ) + else: + if envelope.ratchet_header is not None or envelope.ratchet_init is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="ratchet envelope fields are only allowed for ratchet_v0_2b1", + ) + + def _verify_signature( + self, + *, + payload: MessageSendRequestV2, + envelope: SignedEnvelopeV2, + sender_address: str, + sender_device_uid: str, + ) -> tuple[str, str]: + self._validate_envelope_limits(envelope) + ciphertext_bytes = _decode_b64(envelope.ciphertext_b64) + aad_bytes = _decode_b64(envelope.aad_b64) + ciphertext_hash = _sha256_hex(ciphertext_bytes) + aad_hash = _sha256_hex(aad_bytes) + + canonical = canonical_message_signature_string( + sender_address=sender_address, + sender_device_uid=sender_device_uid, + recipient_user_address=envelope.recipient_user_address, + recipient_device_uid=envelope.recipient_device_uid, + client_message_id=payload.client_message_id, + sent_at_ms=payload.sent_at_ms, + sender_prev_hash=payload.sender_prev_hash, + sender_chain_hash=payload.sender_chain_hash, + ciphertext_hash=ciphertext_hash, + aad_hash=aad_hash, + ) + try: + verify_key = signing.VerifyKey( + envelope.sender_device_pubkey.encode("utf-8"), + encoder=encoding.Base64Encoder, + ) + signature = encoding.Base64Encoder.decode(envelope.signature_b64.encode("utf-8")) + verify_key.verify(canonical, signature) + except (ValueError, nacl_exceptions.BadSignatureError) as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid message signature") from exc + return ciphertext_hash, aad_hash + + @staticmethod + def _aggregate_hash(materials: list[str]) -> str: + joined = "|".join(materials).encode("utf-8") + return _sha256_hex(joined) + + def _compute_chain_hash( + self, + sender_prev_hash: str, + client_message_id: str, + sent_at_ms: int, + envelope_hashes: list[str], + ) -> str: + aggregate = self._aggregate_hash(sorted(envelope_hashes)) + chain_data = "\n".join( + [ + sender_prev_hash, + client_message_id, + str(sent_at_ms), + aggregate, + ] + ).encode("utf-8") + return _sha256_hex(chain_data) + + async def _expected_targets_for_conversation( + self, + session: AsyncSession, + conversation: Conversation, + sender: User, + sender_device_uid: str, + request_authority: str | None = None, + ) -> list[tuple[str, str, str | None, str | None]]: + sender_address = server_address_for_username(sender.username) + targets: list[tuple[str, str, str | None, str | None]] = [] + + sender_devices = await device_service_v2.list_active_for_user(session, sender.id) + for sender_device in sender_devices: + if sender_device.id == sender_device_uid: + continue + targets.append((sender_address, sender_device.id, sender.id, None)) + + if conversation.conversation_type == "group": + await group_conversation_service.ensure_member_access(session, conversation, sender.id, require_active=True) + member_rows = list( + ( + await session.execute( + select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.status == "active", + ) + ) + ).scalars() + ) + for member in member_rows: + if member.member_address == sender_address: + continue + lookup = await device_service_v2.resolve_devices_by_peer_address( + session, + member.member_address, + request_authority=request_authority, + ) + if lookup is None or not lookup.devices: + continue + parsed_remote = parse_peer_address_with_policy(lookup.peer_address, self.settings.tor_enabled) + additional_aliases = {request_authority} if request_authority else None + is_local = is_local_server_authority(parsed_remote.server_onion, self.settings, additional_aliases) + peer_onion = None if is_local else parsed_remote.server_onion + local_member_user_id = member.member_user_id + if is_local and local_member_user_id is None: + peer_stmt = select(User).where(User.username == parsed_remote.username, User.disabled_at.is_(None)) + peer_user = (await session.execute(peer_stmt)).scalar_one_or_none() + local_member_user_id = None if peer_user is None else peer_user.id + for device_out in lookup.devices: + targets.append( + ( + lookup.peer_address, + device_out.device_uid, + local_member_user_id if is_local else None, + peer_onion, + ) + ) + if not targets: + targets.append((sender_address, sender_device_uid, sender.id, None)) + return targets + + if conversation.kind == "local": + peer_user_id = conversation_service.peer_id(conversation, sender.id) + peer_stmt = select(User).where(User.id == peer_user_id, User.disabled_at.is_(None)) + peer_user = (await session.execute(peer_stmt)).scalar_one_or_none() + if peer_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Peer user not found") + peer_address = server_address_for_username(peer_user.username) + peer_devices = await device_service_v2.list_active_for_user(session, peer_user_id) + for peer_device in peer_devices: + targets.append((peer_address, peer_device.id, peer_user_id, None)) + if not targets: + targets.append((sender_address, sender_device_uid, sender.id, None)) + return targets + + if not conversation.peer_address or not conversation.peer_server_onion: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Remote conversation metadata missing") + + remote_resolution = await device_service_v2.resolve_devices_by_peer_address( + session, + conversation.peer_address, + request_authority=request_authority, + ) + if remote_resolution is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Remote peer devices not found") + try: + parsed_remote = parse_peer_address_with_policy(remote_resolution.peer_address, self.settings.tor_enabled) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + additional_aliases = {request_authority} if request_authority else None + if is_local_server_authority(parsed_remote.server_onion, self.settings, additional_aliases): + peer_stmt = select(User).where(User.username == parsed_remote.username, User.disabled_at.is_(None)) + peer_user = (await session.execute(peer_stmt)).scalar_one_or_none() + if peer_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Peer user not found") + for local_device in remote_resolution.devices: + targets.append((remote_resolution.peer_address, local_device.device_uid, peer_user.id, None)) + if not targets: + targets.append((sender_address, sender_device_uid, sender.id, None)) + return targets + for remote_device in remote_resolution.devices: + targets.append((remote_resolution.peer_address, remote_device.device_uid, None, parsed_remote.server_onion)) + if not targets: + targets.append((sender_address, sender_device_uid, sender.id, None)) + return targets + + async def _canonical_conversation_for_sender( + self, + session: AsyncSession, + conversation: Conversation, + sender: User, + request_authority: str | None = None, + ) -> Conversation: + if conversation.kind != "remote" or not conversation.peer_address: + return conversation + try: + parsed = parse_peer_address_with_policy(conversation.peer_address, self.settings.tor_enabled) + except ValueError: + return conversation + additional_aliases = {request_authority} if request_authority else None + if not is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): + return conversation + return await conversation_service.create_dm( + session, + sender, + parsed.username, + None, + request_authority=request_authority, + ) + + @staticmethod + def _event_payload(event: MessageEvent, copy: MessageDeviceCopy) -> dict[str, Any]: + return { + "type": "message.new", + "copy_id": copy.id, + "message": { + "id": event.id, + "conversation_id": event.conversation_id, + "sender_user_id": event.sender_user_id or "", + "sender_address": event.sender_address, + "sender_device_uid": event.sender_device_uid, + "sender_device_pubkey": event.sender_device_pubkey, + "encryption_mode": event.encryption_mode, + "client_message_id": event.client_message_id, + "sent_at_ms": event.sent_at_ms, + "sender_prev_hash": event.sender_prev_hash, + "sender_chain_hash": event.sender_chain_hash, + "envelope": copy.envelope_json, + "created_at": event.created_at.isoformat(), + }, + } + + async def _deliver_local_copies(self, session: AsyncSession, copies: list[MessageDeviceCopy], event: MessageEvent) -> None: + delivered_count = 0 + for copy in copies: + sent = await connection_manager.send_to_device(copy.recipient_device_uid, self._event_payload(event, copy)) + if sent > 0: + copy.status = "delivered" + copy.delivered_at = datetime.now(UTC) + copy.attempt_count += sent + copy.last_attempt_at = datetime.now(UTC) + delivered_count += sent + else: + copy.attempt_count += 1 + copy.last_attempt_at = datetime.now(UTC) + if copies: + await session.commit() + if delivered_count > 0: + await metrics.inc("messages.v2.delivered.realtime", delivered_count) + + async def _validate_chain( + self, + session: AsyncSession, + conversation_id: str, + sender_prev_hash: str, + sender_device_uid: str, + ) -> None: + stmt = ( + select(MessageEvent) + .where( + MessageEvent.conversation_id == conversation_id, + MessageEvent.sender_device_uid == sender_device_uid, + ) + .order_by(MessageEvent.created_at.desc()) + .limit(1) + ) + previous = (await session.execute(stmt)).scalar_one_or_none() + expected_prev = previous.sender_chain_hash if previous is not None else "" + if sender_prev_hash != expected_prev: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="sender_prev_hash does not match latest chain value", + ) + + async def _has_prior_sender_event( + self, + session: AsyncSession, + conversation_id: str, + sender_device_uid: str, + ) -> bool: + stmt = ( + select(MessageEvent.id) + .where( + MessageEvent.conversation_id == conversation_id, + MessageEvent.sender_device_uid == sender_device_uid, + ) + .limit(1) + ) + return (await session.execute(stmt)).scalar_one_or_none() is not None + + async def send_message( + self, + session: AsyncSession, + sender: User, + sender_device: Device, + payload: MessageSendRequestV2, + request_authority: str | None = None, + ) -> tuple[MessageEvent, bool]: + conversation = await conversation_service.get_by_id(session, payload.conversation_id) + if conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found") + if conversation.conversation_type == "group": + await group_conversation_service.ensure_member_access(session, conversation, sender.id, require_active=True) + else: + conversation_service.ensure_membership(conversation, sender.id) + conversation = await self._canonical_conversation_for_sender( + session, + conversation, + sender, + request_authority=request_authority, + ) + encryption_mode = self._validate_encryption_mode(payload.encryption_mode) + + duplicate_stmt = select(MessageEvent).where( + MessageEvent.sender_device_uid == sender_device.id, + MessageEvent.client_message_id == payload.client_message_id, + ) + duplicate = (await session.execute(duplicate_stmt)).scalar_one_or_none() + if duplicate is not None: + return duplicate, True + had_prior_event = await self._has_prior_sender_event(session, conversation.id, sender_device.id) + + sender_address = server_address_for_username(sender.username) + expected_targets = await self._expected_targets_for_conversation( + session, + conversation, + sender, + sender_device.id, + request_authority=request_authority, + ) + policy_scope = "remote" if conversation.kind == "remote" else "local" + if conversation.conversation_type == "group": + policy_scope = "remote" if any(peer_onion for _, _, _, peer_onion in expected_targets) else "local" + self._enforce_encryption_policy(policy_scope, encryption_mode) + if conversation.conversation_type == "group": + await metrics.inc("groups.messages.fanout_targets", len(expected_targets)) + expected_pairs = {(address, device_uid) for address, device_uid, _, _ in expected_targets} + envelope_pairs = {(env.recipient_user_address.strip().lower(), env.recipient_device_uid) for env in payload.envelopes} + if envelope_pairs != expected_pairs: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Envelope recipients do not match expected active device fanout", + ) + + await self._validate_chain(session, conversation.id, payload.sender_prev_hash, sender_device.id) + + envelope_hash_material: list[str] = [] + envelopes_by_target: dict[tuple[str, str], SignedEnvelopeV2] = {} + for envelope in payload.envelopes: + self._validate_mode_envelope_shape(encryption_mode, envelope) + if envelope.sender_device_pubkey != sender_device.ik_ed25519_pub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="sender_device_pubkey must match sender registered key", + ) + ciphertext_hash, aad_hash = self._verify_signature( + payload=payload, + envelope=envelope, + sender_address=sender_address, + sender_device_uid=sender_device.id, + ) + envelope_hash_material.append(f"{envelope.recipient_device_uid}:{ciphertext_hash}:{aad_hash}") + key = (envelope.recipient_user_address.strip().lower(), envelope.recipient_device_uid) + envelopes_by_target[key] = envelope + + expected_chain_hash = self._compute_chain_hash( + payload.sender_prev_hash, + payload.client_message_id, + payload.sent_at_ms, + envelope_hash_material, + ) + if expected_chain_hash != payload.sender_chain_hash: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="sender_chain_hash is invalid") + + message_event = MessageEvent( + conversation_id=conversation.id, + sender_user_id=sender.id, + sender_address=sender_address, + sender_device_uid=sender_device.id, + sender_device_pubkey=sender_device.ik_ed25519_pub, + client_message_id=payload.client_message_id, + sent_at_ms=payload.sent_at_ms, + encryption_mode=encryption_mode, + sender_prev_hash=payload.sender_prev_hash, + sender_chain_hash=payload.sender_chain_hash, + ) + session.add(message_event) + await session.flush() + + local_copies: list[MessageDeviceCopy] = [] + local_copy_candidates: list[tuple[str, str, dict[str, Any]]] = [] + remote_envelopes_by_peer: dict[str, list[dict[str, Any]]] = {} + for address, device_uid, recipient_user_id, peer_onion in expected_targets: + envelope = envelopes_by_target[(address, device_uid)] + if recipient_user_id is None: + if not peer_onion: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Remote target missing peer onion") + remote_envelopes_by_peer.setdefault(peer_onion, []).append( + { + "recipient_user_address": address, + "recipient_device_uid": device_uid, + "ciphertext_b64": envelope.ciphertext_b64, + "aad_b64": envelope.aad_b64, + "signature_b64": envelope.signature_b64, + "sender_device_pubkey": envelope.sender_device_pubkey, + "ratchet_header": None if envelope.ratchet_header is None else envelope.ratchet_header.model_dump(), + "ratchet_init": None if envelope.ratchet_init is None else envelope.ratchet_init.model_dump(), + "recipient_user_id": "", + } + ) + continue + + envelope_json = { + "recipient_user_address": address, + "recipient_device_uid": device_uid, + "ciphertext_b64": envelope.ciphertext_b64, + "aad_b64": envelope.aad_b64, + "signature_b64": envelope.signature_b64, + "sender_device_pubkey": envelope.sender_device_pubkey, + "ratchet_header": None if envelope.ratchet_header is None else envelope.ratchet_header.model_dump(), + "ratchet_init": None if envelope.ratchet_init is None else envelope.ratchet_init.model_dump(), + } + local_copy_candidates.append( + ( + recipient_user_id, + device_uid, + envelope_json, + ) + ) + + await self._enforce_pending_queue_limits( + session, + [(device_uid, envelope_json) for _, device_uid, envelope_json in local_copy_candidates], + ) + + for recipient_user_id, recipient_device_uid, envelope_json in local_copy_candidates: + copy = MessageDeviceCopy( + message_event_id=message_event.id, + recipient_user_id=recipient_user_id, + recipient_device_uid=recipient_device_uid, + envelope_json=envelope_json, + status="pending", + expires_at=datetime.now(UTC) + timedelta(days=self.settings.message_ttl_days), + ) + session.add(copy) + local_copies.append(copy) + + await session.commit() + await session.refresh(message_event) + await metrics.inc("messages.v2.sent") + if encryption_mode == RATCHET_MODE: + await metrics.inc("messages.v2b1.mode.ratchet.sent") + if not had_prior_event: + await metrics.inc("ratchet.session.established") + elif self.settings.enable_ratchet_v2b1: + await metrics.inc("messages.v2b1.mode.sealedbox.fallback") + + if local_copies: + await self._deliver_local_copies(session, local_copies, message_event) + + if remote_envelopes_by_peer: + relay_payload = { + "relay_id": str(uuid4()), + "conversation_id": conversation.id, + "sender_address": sender_address, + "sender_device_uid": sender_device.id, + "sender_user_id": sender.id, + "encryption_mode": encryption_mode, + "client_message_id": payload.client_message_id, + "sent_at_ms": payload.sent_at_ms, + "sender_prev_hash": payload.sender_prev_hash, + "sender_chain_hash": payload.sender_chain_hash, + } + if conversation.conversation_type == "group" and conversation.group_uid: + relay_payload["group_uid"] = conversation.group_uid + for peer_onion, remote_envelopes in remote_envelopes_by_peer.items(): + payload_for_peer = dict(relay_payload) + payload_for_peer["envelopes"] = remote_envelopes + dedupe_key = f"v2-message:{sender_device.id}:{payload.client_message_id}:{conversation.id}:{peer_onion}" + outbox_item = await federation_outbox_service.enqueue( + session, + peer_onion=peer_onion, + event_type="message.v2.relay", + endpoint_path="/api/v2/federation/messages/relay", + payload_json=payload_for_peer, + dedupe_key=dedupe_key, + ) + await session.commit() + await federation_outbox_service.deliver_item(session, outbox_item.id) + if conversation.conversation_type == "group": + await metrics.inc("groups.messages.relay_servers", len(remote_envelopes_by_peer)) + + return message_event, False + + async def list_messages_for_device( + self, + session: AsyncSession, + user: User, + device_uid: str, + conversation_id: str, + limit: int, + offset: int, + ) -> list[tuple[MessageDeviceCopy, MessageEvent]]: + conversation = await conversation_service.get_by_id(session, conversation_id) + if conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found") + member = None + if conversation.conversation_type == "group": + member = await group_conversation_service.ensure_member_access( + session, + conversation, + user.id, + require_active=True, + ) + else: + conversation_service.ensure_membership(conversation, user.id) + + stmt = ( + select(MessageDeviceCopy, MessageEvent) + .join(MessageEvent, MessageEvent.id == MessageDeviceCopy.message_event_id) + .where( + MessageDeviceCopy.recipient_user_id == user.id, + MessageDeviceCopy.recipient_device_uid == device_uid, + MessageEvent.conversation_id == conversation_id, + ) + .order_by(MessageEvent.created_at.desc()) + .offset(offset) + .limit(limit) + ) + rows = list((await session.execute(stmt)).all()) + if conversation.conversation_type == "group" and member is not None and member.joined_at is not None: + rows = [row for row in rows if row[1].created_at >= member.joined_at] + return rows + + async def drain_pending_for_websocket( + self, + session: AsyncSession, + user_id: str, + device_uid: str, + websocket: Any, + ) -> None: + await self.expire_old(session) + now = datetime.now(UTC) + stmt = ( + select(MessageDeviceCopy, MessageEvent) + .join(MessageEvent, MessageEvent.id == MessageDeviceCopy.message_event_id) + .where( + MessageDeviceCopy.recipient_user_id == user_id, + MessageDeviceCopy.recipient_device_uid == device_uid, + MessageDeviceCopy.status == "pending", + MessageDeviceCopy.available_at <= now, + MessageDeviceCopy.expires_at > now, + ) + .order_by(MessageDeviceCopy.available_at.asc()) + ) + rows = list((await session.execute(stmt)).all()) + for copy, event in rows: + await websocket.send_json(self._event_payload(event, copy)) + copy.attempt_count += 1 + copy.last_attempt_at = datetime.now(UTC) + if rows: + await session.commit() + + async def acknowledge(self, session: AsyncSession, user_id: str, device_uid: str, copy_id: str) -> bool: + stmt = select(MessageDeviceCopy).where( + MessageDeviceCopy.id == copy_id, + MessageDeviceCopy.recipient_user_id == user_id, + MessageDeviceCopy.recipient_device_uid == device_uid, + MessageDeviceCopy.status == "pending", + ) + copy = (await session.execute(stmt)).scalar_one_or_none() + if copy is None: + return False + copy.status = "delivered" + copy.delivered_at = datetime.now(UTC) + await session.commit() + await metrics.inc("messages.v2.acknowledged") + return True + + async def expire_old(self, session: AsyncSession) -> int: + now = datetime.now(UTC) + stmt = select(MessageDeviceCopy).where( + MessageDeviceCopy.status == "pending", + MessageDeviceCopy.expires_at <= now, + ) + rows = list((await session.execute(stmt)).scalars().all()) + for row in rows: + row.status = "expired" + if rows: + await session.commit() + await metrics.inc("messages.v2.expired", len(rows)) + return len(rows) + + async def relay_message_from_federation( + self, + session: AsyncSession, + payload: FederationMessageRelayRequestV2, + ) -> tuple[MessageEvent, bool]: + if not payload.envelopes: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Relay payload has no envelopes") + encryption_mode = self._validate_encryption_mode(payload.encryption_mode) + self._enforce_encryption_policy("remote", encryption_mode) + + duplicate_stmt = select(MessageEvent).where( + MessageEvent.sender_device_uid == payload.sender_device_uid, + MessageEvent.client_message_id == payload.client_message_id, + ) + duplicate = (await session.execute(duplicate_stmt)).scalar_one_or_none() + if duplicate is not None: + return duplicate, True + # Ratchet sessions are tracked by (conversation, sender_device_uid) lifecycle. + had_prior_event = False + + first_address = payload.envelopes[0].recipient_user_address.strip().lower() + first_parsed = parse_peer_address_with_policy(first_address, self.settings.tor_enabled) + if first_parsed.server_onion != get_server_onion(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Recipient server mismatch") + + conversation: Conversation + group_member_by_address: dict[str, ConversationMember] = {} + if payload.group_uid: + group_conversation = await group_conversation_service.get_group_by_uid(session, payload.group_uid) + if group_conversation is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group conversation not found") + conversation = group_conversation + members = list( + ( + await session.execute( + select(ConversationMember).where( + ConversationMember.conversation_id == conversation.id, + ConversationMember.status == "active", + ) + ) + ).scalars() + ) + group_member_by_address = {row.member_address: row for row in members} + else: + user_stmt = select(User).where(User.username == first_parsed.username, User.disabled_at.is_(None)) + recipient_user = (await session.execute(user_stmt)).scalar_one_or_none() + if recipient_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Recipient user not found") + conversation = await conversation_service.get_or_create_remote_for_local_user( + session, + recipient_user, + payload.sender_address, + ) + had_prior_event = await self._has_prior_sender_event(session, conversation.id, payload.sender_device_uid) + await self._validate_chain( + session, + conversation.id, + payload.sender_prev_hash, + payload.sender_device_uid, + ) + + envelope_hash_material: list[str] = [] + local_copies: list[MessageDeviceCopy] = [] + local_copy_candidates: list[tuple[str, str, dict[str, Any]]] = [] + for envelope in payload.envelopes: + parsed = parse_peer_address_with_policy(envelope.recipient_user_address, self.settings.tor_enabled) + if parsed.server_onion != get_server_onion(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid relay recipient") + recipient_user_id = "" + if payload.group_uid: + member = group_member_by_address.get(parsed.canonical) + if member is None or member.member_user_id is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid group relay recipient") + recipient_user_id = member.member_user_id + else: + user_stmt = select(User).where(User.username == parsed.username, User.disabled_at.is_(None)) + recipient_user = (await session.execute(user_stmt)).scalar_one_or_none() + if recipient_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Recipient user not found") + recipient_user_id = recipient_user.id + device_stmt = select(Device).where( + Device.id == envelope.recipient_device_uid, + Device.user_id == recipient_user_id, + Device.status == "active", + Device.revoked_at.is_(None), + ) + recipient_device = (await session.execute(device_stmt)).scalar_one_or_none() + if recipient_device is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Recipient device invalid") + + signed_env = SignedEnvelopeV2( + recipient_user_address=envelope.recipient_user_address, + recipient_device_uid=envelope.recipient_device_uid, + ciphertext_b64=envelope.ciphertext_b64, + aad_b64=envelope.aad_b64, + signature_b64=envelope.signature_b64, + sender_device_pubkey=envelope.sender_device_pubkey, + ratchet_header=envelope.ratchet_header, + ratchet_init=envelope.ratchet_init, + ) + self._validate_mode_envelope_shape(encryption_mode, signed_env) + ciphertext_hash, aad_hash = self._verify_signature( + payload=MessageSendRequestV2( + conversation_id=conversation.id, + encryption_mode=encryption_mode, + client_message_id=payload.client_message_id, + sent_at_ms=payload.sent_at_ms, + sender_prev_hash=payload.sender_prev_hash, + sender_chain_hash=payload.sender_chain_hash, + envelopes=[signed_env], + ), + envelope=signed_env, + sender_address=payload.sender_address, + sender_device_uid=payload.sender_device_uid, + ) + envelope_hash_material.append( + f"{envelope.recipient_device_uid}:{ciphertext_hash}:{aad_hash}" + ) + envelope_json = { + "recipient_user_address": envelope.recipient_user_address, + "recipient_device_uid": envelope.recipient_device_uid, + "ciphertext_b64": envelope.ciphertext_b64, + "aad_b64": envelope.aad_b64, + "signature_b64": envelope.signature_b64, + "sender_device_pubkey": envelope.sender_device_pubkey, + "ratchet_header": None if envelope.ratchet_header is None else envelope.ratchet_header.model_dump(), + "ratchet_init": None if envelope.ratchet_init is None else envelope.ratchet_init.model_dump(), + } + local_copy_candidates.append((recipient_user_id, envelope.recipient_device_uid, envelope_json)) + + expected_chain_hash = self._compute_chain_hash( + payload.sender_prev_hash, + payload.client_message_id, + payload.sent_at_ms, + envelope_hash_material, + ) + if expected_chain_hash != payload.sender_chain_hash: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="sender_chain_hash is invalid") + + event = MessageEvent( + conversation_id=conversation.id, + sender_user_id=None, + sender_address=payload.sender_address, + sender_device_uid=payload.sender_device_uid, + sender_device_pubkey=payload.envelopes[0].sender_device_pubkey, + client_message_id=payload.client_message_id, + sent_at_ms=payload.sent_at_ms, + encryption_mode=encryption_mode, + sender_prev_hash=payload.sender_prev_hash, + sender_chain_hash=payload.sender_chain_hash, + ) + session.add(event) + await session.flush() + await self._enforce_pending_queue_limits(session, [(device_uid, item) for _, device_uid, item in local_copy_candidates]) + + for recipient_user_id, recipient_device_uid, envelope_json in local_copy_candidates: + copy = MessageDeviceCopy( + recipient_user_id=recipient_user_id, + recipient_device_uid=recipient_device_uid, + envelope_json=envelope_json, + status="pending", + expires_at=datetime.now(UTC) + timedelta(days=self.settings.message_ttl_days), + message_event_id=event.id, + ) + session.add(copy) + local_copies.append(copy) + await session.commit() + await session.refresh(event) + + await self._deliver_local_copies(session, local_copies, event) + if encryption_mode == RATCHET_MODE: + await metrics.inc("messages.v2b1.mode.ratchet.recv") + if not had_prior_event: + await metrics.inc("ratchet.session.established") + elif self.settings.enable_ratchet_v2b1: + await metrics.inc("messages.v2b1.mode.sealedbox.fallback") + return event, False + + +message_service_v2 = MessageServiceV2() diff --git a/server/app/services/prekey_service_v2.py b/server/app/services/prekey_service_v2.py new file mode 100644 index 0000000..eae3519 --- /dev/null +++ b/server/app/services/prekey_service_v2.py @@ -0,0 +1,297 @@ +from datetime import UTC, datetime, timedelta + +from fastapi import HTTPException, status +from nacl import encoding, exceptions as nacl_exceptions, signing +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import get_settings +from app.models.device import Device +from app.models.device_one_time_prekey import DeviceOneTimePrekey +from app.models.device_signed_prekey import DeviceSignedPrekey +from app.models.user import User +from app.schemas.v2_prekey import ( + OneTimePrekeyOutV2, + PrekeyUploadRequestV2, + PrekeyUploadResponseV2, + ResolvePrekeysResponseV2, + ResolvedPrekeyDeviceV2, + SignedPrekeyOutV2, +) +from app.services.device_service_v2 import device_service_v2 +from app.services.federation_client import FederationClientError, federation_client +from app.services.metrics import metrics +from app.services.peer_address import parse_peer_address_with_policy +from app.services.server_authority import is_local_server_authority +from app.services.server_identity import get_server_onion, server_address_for_username + + +def canonical_signed_prekey_string( + device_uid: str, + key_id: int, + pub_x25519_b64: str, + expires_at: datetime, +) -> bytes: + return "\n".join( + [ + "SIGNED_PREKEY", + device_uid.strip(), + str(key_id), + pub_x25519_b64.strip(), + expires_at.astimezone(UTC).isoformat(), + ] + ).encode("utf-8") + + +class PrekeyServiceV2: + def __init__(self) -> None: + self.settings = get_settings() + + async def upload_prekeys( + self, + session: AsyncSession, + user: User, + device: Device, + payload: PrekeyUploadRequestV2, + ) -> PrekeyUploadResponseV2: + expires_at = payload.signed_prekey.expires_at or (datetime.now(UTC) + timedelta(days=30)) + canonical = canonical_signed_prekey_string( + device_uid=device.id, + key_id=payload.signed_prekey.key_id, + pub_x25519_b64=payload.signed_prekey.pub_x25519_b64, + expires_at=expires_at, + ) + try: + verify_key = signing.VerifyKey(device.ik_ed25519_pub.encode("utf-8"), encoder=encoding.Base64Encoder) + signature = encoding.Base64Encoder.decode(payload.signed_prekey.sig_by_device_sign_key_b64.encode("utf-8")) + verify_key.verify(canonical, signature) + except (ValueError, nacl_exceptions.BadSignatureError) as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid signed prekey signature") from exc + + signed_stmt = select(DeviceSignedPrekey).where( + DeviceSignedPrekey.device_id == device.id, + DeviceSignedPrekey.key_id == payload.signed_prekey.key_id, + ) + signed = (await session.execute(signed_stmt)).scalar_one_or_none() + if signed is None: + signed = DeviceSignedPrekey( + device_id=device.id, + key_id=payload.signed_prekey.key_id, + pub_x25519_b64=payload.signed_prekey.pub_x25519_b64, + sig_by_device_sign_key_b64=payload.signed_prekey.sig_by_device_sign_key_b64, + expires_at=expires_at, + revoked_at=None, + ) + session.add(signed) + else: + signed.pub_x25519_b64 = payload.signed_prekey.pub_x25519_b64 + signed.sig_by_device_sign_key_b64 = payload.signed_prekey.sig_by_device_sign_key_b64 + signed.expires_at = expires_at + signed.revoked_at = None + + key_ids = {item.key_id for item in payload.one_time_prekeys} + accepted = 0 + if key_ids: + existing_stmt = select(DeviceOneTimePrekey.key_id).where( + DeviceOneTimePrekey.device_id == device.id, + DeviceOneTimePrekey.key_id.in_(key_ids), + ) + existing = {row[0] for row in (await session.execute(existing_stmt)).all()} + for item in payload.one_time_prekeys: + if item.key_id in existing: + continue + session.add( + DeviceOneTimePrekey( + device_id=device.id, + key_id=item.key_id, + pub_x25519_b64=item.pub_x25519_b64, + ) + ) + accepted += 1 + + await session.commit() + return PrekeyUploadResponseV2( + uploaded_signed_prekey_key_id=payload.signed_prekey.key_id, + accepted_one_time_prekeys=accepted, + ) + + async def resolve_prekeys( + self, + session: AsyncSession, + requester_user: User, + peer_address: str, + request_authority: str | None = None, + ) -> ResolvePrekeysResponseV2: + parsed = parse_peer_address_with_policy(peer_address, self.settings.tor_enabled) + additional_aliases = {request_authority} if request_authority else None + if not is_local_server_authority(parsed.server_onion, self.settings, additional_aliases): + try: + return await federation_client.get_remote_user_prekeys_v2(parsed.server_onion, parsed.username) + except FederationClientError as exc: + if exc.status_code == 404: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Remote peer prekeys not found") from exc + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=exc.detail) from exc + + lookup = await device_service_v2.resolve_devices_by_peer_address( + session, + parsed.canonical, + request_authority=request_authority, + ) + if lookup is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Peer devices not found") + + requester_address = server_address_for_username(requester_user.username) + out_devices: list[ResolvedPrekeyDeviceV2] = [] + opk_missing_count = 0 + now = datetime.now(UTC) + for device_out in lookup.devices: + signed_stmt = ( + select(DeviceSignedPrekey) + .where( + DeviceSignedPrekey.device_id == device_out.device_uid, + DeviceSignedPrekey.revoked_at.is_(None), + DeviceSignedPrekey.expires_at > now, + ) + .order_by(DeviceSignedPrekey.created_at.desc()) + .limit(1) + ) + signed = (await session.execute(signed_stmt)).scalar_one_or_none() + + opk_stmt = ( + select(DeviceOneTimePrekey) + .where( + DeviceOneTimePrekey.device_id == device_out.device_uid, + DeviceOneTimePrekey.consumed_at.is_(None), + ) + .order_by(DeviceOneTimePrekey.created_at.asc()) + .limit(1) + .with_for_update(skip_locked=True) + ) + opk = (await session.execute(opk_stmt)).scalar_one_or_none() + if opk is not None: + opk.consumed_at = now + opk.consumed_by_address = requester_address + else: + opk_missing_count += 1 + + out_devices.append( + ResolvedPrekeyDeviceV2( + device_uid=device_out.device_uid, + pub_sign_key=device_out.pub_sign_key, + pub_dh_key=device_out.pub_dh_key, + supported_message_modes=device_out.supported_message_modes, + signed_prekey=( + None + if signed is None + else SignedPrekeyOutV2( + key_id=signed.key_id, + pub_x25519_b64=signed.pub_x25519_b64, + sig_by_device_sign_key_b64=signed.sig_by_device_sign_key_b64, + expires_at=signed.expires_at, + ) + ), + one_time_prekey=( + None + if opk is None + else OneTimePrekeyOutV2( + key_id=opk.key_id, + pub_x25519_b64=opk.pub_x25519_b64, + ) + ), + opk_missing=opk is None, + ) + ) + + await session.commit() + if opk_missing_count > 0: + await metrics.inc("ratchet.prekey.opk.exhausted", opk_missing_count) + return ResolvePrekeysResponseV2( + username=lookup.username, + peer_address=lookup.peer_address, + devices=out_devices, + ) + + async def resolve_local_user_prekeys_for_federation( + self, + session: AsyncSession, + username: str, + requested_by_address: str, + ) -> ResolvePrekeysResponseV2 | None: + canonical = f"{username.strip().lower()}@{get_server_onion()}" + parsed = parse_peer_address_with_policy(canonical, self.settings.tor_enabled) + lookup = await device_service_v2.resolve_devices_by_peer_address(session, parsed.canonical) + if lookup is None: + return None + + out_devices: list[ResolvedPrekeyDeviceV2] = [] + opk_missing_count = 0 + now = datetime.now(UTC) + for device_out in lookup.devices: + signed_stmt = ( + select(DeviceSignedPrekey) + .where( + DeviceSignedPrekey.device_id == device_out.device_uid, + DeviceSignedPrekey.revoked_at.is_(None), + DeviceSignedPrekey.expires_at > now, + ) + .order_by(DeviceSignedPrekey.created_at.desc()) + .limit(1) + ) + signed = (await session.execute(signed_stmt)).scalar_one_or_none() + + opk_stmt = ( + select(DeviceOneTimePrekey) + .where( + DeviceOneTimePrekey.device_id == device_out.device_uid, + DeviceOneTimePrekey.consumed_at.is_(None), + ) + .order_by(DeviceOneTimePrekey.created_at.asc()) + .limit(1) + .with_for_update(skip_locked=True) + ) + opk = (await session.execute(opk_stmt)).scalar_one_or_none() + if opk is not None: + opk.consumed_at = now + opk.consumed_by_address = requested_by_address + else: + opk_missing_count += 1 + + out_devices.append( + ResolvedPrekeyDeviceV2( + device_uid=device_out.device_uid, + pub_sign_key=device_out.pub_sign_key, + pub_dh_key=device_out.pub_dh_key, + supported_message_modes=device_out.supported_message_modes, + signed_prekey=( + None + if signed is None + else SignedPrekeyOutV2( + key_id=signed.key_id, + pub_x25519_b64=signed.pub_x25519_b64, + sig_by_device_sign_key_b64=signed.sig_by_device_sign_key_b64, + expires_at=signed.expires_at, + ) + ), + one_time_prekey=( + None + if opk is None + else OneTimePrekeyOutV2( + key_id=opk.key_id, + pub_x25519_b64=opk.pub_x25519_b64, + ) + ), + opk_missing=opk is None, + ) + ) + + await session.commit() + if opk_missing_count > 0: + await metrics.inc("ratchet.prekey.opk.exhausted", opk_missing_count) + return ResolvePrekeysResponseV2( + username=lookup.username, + peer_address=lookup.peer_address, + devices=out_devices, + ) + + +prekey_service_v2 = PrekeyServiceV2() diff --git a/server/app/services/presence_service.py b/server/app/services/presence_service.py new file mode 100644 index 0000000..04f210f --- /dev/null +++ b/server/app/services/presence_service.py @@ -0,0 +1,77 @@ +import asyncio + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.user import User +from app.services.server_identity import get_server_onion +from app.ws.manager import connection_manager + + +class PresenceService: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._manual_status_by_user_id: dict[str, str] = {} + + async def set_status(self, user_id: str, status: str) -> str: + normalized = status.strip().lower() + async with self._lock: + self._manual_status_by_user_id[user_id] = normalized + return normalized + + async def effective_status(self, user_id: str) -> str: + async with self._lock: + manual = self._manual_status_by_user_id.get(user_id, "active") + if manual == "offline": + return "offline" + is_online = await connection_manager.has_user(user_id) + if not is_online: + return "offline" + return manual + + async def resolve_for_peer_addresses( + self, + session: AsyncSession, + peer_addresses: list[str], + ) -> list[tuple[str, str]]: + normalized = [address.strip().lower() for address in peer_addresses if address.strip()] + if not normalized: + return [] + + local_server = get_server_onion().strip().lower() + usernames: set[str] = set() + for peer_address in normalized: + if "@" not in peer_address: + continue + username, server_onion = peer_address.split("@", 1) + if not username or not server_onion: + continue + if server_onion != local_server: + continue + usernames.add(username) + + by_username: dict[str, User] = {} + if usernames: + stmt = select(User).where(User.username.in_(sorted(usernames)), User.disabled_at.is_(None)) + users = (await session.execute(stmt)).scalars().all() + by_username = {user.username.strip().lower(): user for user in users} + + resolved: list[tuple[str, str]] = [] + for peer_address in normalized: + if "@" not in peer_address: + resolved.append((peer_address, "offline")) + continue + username, server_onion = peer_address.split("@", 1) + if not username or not server_onion or server_onion != local_server: + resolved.append((peer_address, "offline")) + continue + user = by_username.get(username) + if user is None: + resolved.append((peer_address, "offline")) + continue + resolved.append((peer_address, await self.effective_status(user.id))) + + return resolved + + +presence_service = PresenceService() diff --git a/server/app/services/rate_limit.py b/server/app/services/rate_limit.py index b79401c..e75152a 100644 --- a/server/app/services/rate_limit.py +++ b/server/app/services/rate_limit.py @@ -31,22 +31,30 @@ async def close(self) -> None: async def enforce(self, key: str, limit: int | None = None) -> None: current_limit = limit or self.settings.rate_limit_per_minute - allowed = await self._allow(key, current_limit) + allowed = await self._allow(key, current_limit, units=1) if not allowed: raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded") - async def _allow(self, key: str, limit: int) -> bool: + async def enforce_weighted(self, key: str, units: int, limit: int) -> None: + effective_units = max(units, 0) + if effective_units == 0: + return + allowed = await self._allow(key, limit, units=effective_units) + if not allowed: + raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded") + + async def _allow(self, key: str, limit: int, units: int) -> bool: current_window = int(time.time() // 60) if self._redis_client is not None: redis_key = f"rl:{key}:{current_window}" - raw_count = await self._redis_client.incr(redis_key) + raw_count = await self._redis_client.incrby(redis_key, units) count = cast(int, raw_count) - if count == 1: + if count == units: await self._redis_client.expire(redis_key, 70) return count <= limit window_counts = self._memory_counts[key] - count = window_counts.get(current_window, 0) + 1 + count = window_counts.get(current_window, 0) + units window_counts[current_window] = count stale_keys = [window for window in window_counts if window < current_window - 1] diff --git a/server/app/services/server_authority.py b/server/app/services/server_authority.py new file mode 100644 index 0000000..ec72b8f --- /dev/null +++ b/server/app/services/server_authority.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from urllib.parse import urlsplit + +from app.config import Settings +from app.services.server_identity import get_server_onion + + +def normalize_server_authority(value: str) -> str: + normalized = value.strip().lower() + if normalized.startswith("http://"): + normalized = normalized[7:] + if normalized.startswith("https://"): + normalized = normalized[8:] + return normalized.rstrip("/") + + +def _host_port(authority: str) -> tuple[str, int | None]: + normalized = normalize_server_authority(authority) + if not normalized: + return "", None + if normalized.count(":") >= 2 and not normalized.startswith("["): + normalized = f"[{normalized}]" + parsed = urlsplit(f"http://{normalized}") + host = (parsed.hostname or normalized).strip().lower() + try: + port = parsed.port + except ValueError: + return host, None + return host, port + + +def authority_matches(left: str, right: str) -> bool: + left_norm = normalize_server_authority(left) + right_norm = normalize_server_authority(right) + if not left_norm or not right_norm: + return False + if left_norm == right_norm: + return True + left_host, left_port = _host_port(left_norm) + right_host, right_port = _host_port(right_norm) + if not left_host or not right_host: + return False + if left_host != right_host: + return False + return left_port is None or right_port is None or left_port == right_port + + +def local_server_authorities(settings: Settings) -> set[str]: + authorities = { + normalize_server_authority(get_server_onion()), + normalize_server_authority(settings.federation_server_onion), + "localhost", + "127.0.0.1", + "[::1]", + } + configured = settings.local_server_aliases.strip() + if configured: + for item in configured.split(","): + alias = normalize_server_authority(item) + if alias: + authorities.add(alias) + return {value for value in authorities if value} + + +def is_local_server_authority( + authority: str, + settings: Settings, + additional_aliases: set[str] | None = None, +) -> bool: + target = normalize_server_authority(authority) + if not target: + return False + candidate_authorities = local_server_authorities(settings) + if additional_aliases is not None: + for alias in additional_aliases: + normalized = normalize_server_authority(alias) + if normalized: + candidate_authorities.add(normalized) + for local_authority in candidate_authorities: + if authority_matches(target, local_authority): + return True + return False diff --git a/server/app/services/server_identity.py b/server/app/services/server_identity.py index 9e13230..a38d0de 100644 --- a/server/app/services/server_identity.py +++ b/server/app/services/server_identity.py @@ -40,6 +40,11 @@ def _is_onion_authority(value: str) -> bool: def _build_signing_key(settings: Settings, server_onion: str) -> signing.SigningKey: + if settings.tor_enabled: + tor_key = _load_tor_hidden_service_signing_key(settings, server_onion) + if tor_key is not None: + return tor_key + configured = settings.federation_signing_private_key_b64.strip() if configured: raw = base64.b64decode(configured.encode("utf-8"), validate=True) @@ -47,7 +52,7 @@ def _build_signing_key(settings: Settings, server_onion: str) -> signing.Signing if settings.tor_enabled: raise RuntimeError( - "BLACKWIRE_FEDERATION_SIGNING_PRIVATE_KEY_B64 is required when BLACKWIRE_TOR_ENABLED=true" + "Tor is enabled but federation signing key could not be derived from hidden-service key" ) seed = hashlib.sha256( @@ -56,6 +61,51 @@ def _build_signing_key(settings: Settings, server_onion: str) -> signing.Signing return signing.SigningKey(seed) +def _onion_v3_public_key(authority: str) -> bytes | None: + value = authority.strip().lower() + if value.endswith(".onion"): + value = value[:-6] + if len(value) != 56: + return None + try: + decoded = base64.b32decode(value.upper()) + except Exception: + return None + if len(decoded) != 35: + return None + return decoded[:32] + + +def _load_tor_hidden_service_signing_key(settings: Settings, server_onion: str) -> signing.SigningKey | None: + try: + raw = Path(settings.tor_hs_ed25519_secret_key_file).read_bytes() + except OSError: + return None + + onion_pub = _onion_v3_public_key(server_onion) + candidates: list[bytes] = [] + if len(raw) >= 64: + candidates.append(raw[-64:-32]) + candidates.append(raw[-32:]) + if len(raw) >= 96: + candidates.append(raw[-96:-64]) + if len(raw) >= 32: + candidates.append(raw[:32]) + + seen: set[bytes] = set() + for candidate in candidates: + if len(candidate) != 32 or candidate in seen: + continue + seen.add(candidate) + try: + key = signing.SigningKey(candidate) + except Exception: + continue + if onion_pub is None or key.verify_key.encode() == onion_pub: + return key + return None + + def _wait_for_hidden_service_hostname(path: str, wait_seconds: int) -> str | None: if wait_seconds <= 0: return _load_hidden_service_hostname(path) diff --git a/server/app/ws/manager.py b/server/app/ws/manager.py index 69e213b..4f4e268 100644 --- a/server/app/ws/manager.py +++ b/server/app/ws/manager.py @@ -1,27 +1,56 @@ import asyncio from collections import defaultdict +from dataclasses import dataclass from fastapi import WebSocket +@dataclass(frozen=True, slots=True) +class _ConnectionMeta: + user_id: str + device_uid: str | None + websocket: WebSocket + + class ConnectionManager: def __init__(self) -> None: self._connections: dict[str, set[WebSocket]] = defaultdict(set) + self._device_connections: dict[str, set[WebSocket]] = defaultdict(set) + self._reverse_index: dict[WebSocket, _ConnectionMeta] = {} self._lock = asyncio.Lock() - async def connect(self, user_id: str, websocket: WebSocket) -> None: + async def connect(self, user_id: str, websocket: WebSocket, device_uid: str | None = None) -> None: await websocket.accept() async with self._lock: self._connections[user_id].add(websocket) + if device_uid: + self._device_connections[device_uid].add(websocket) + self._reverse_index[websocket] = _ConnectionMeta( + user_id=user_id, + device_uid=device_uid, + websocket=websocket, + ) - async def disconnect(self, user_id: str, websocket: WebSocket) -> None: + async def disconnect(self, user_id: str, websocket: WebSocket, device_uid: str | None = None) -> None: async with self._lock: - sockets = self._connections.get(user_id) + meta = self._reverse_index.pop(websocket, None) + resolved_user_id = user_id or (meta.user_id if meta else "") + resolved_device_uid = device_uid or (meta.device_uid if meta else None) + + sockets = self._connections.get(resolved_user_id) if sockets is None: - return - sockets.discard(websocket) - if not sockets: - self._connections.pop(user_id, None) + pass + else: + sockets.discard(websocket) + if not sockets: + self._connections.pop(resolved_user_id, None) + + if resolved_device_uid: + device_sockets = self._device_connections.get(resolved_device_uid) + if device_sockets is not None: + device_sockets.discard(websocket) + if not device_sockets: + self._device_connections.pop(resolved_device_uid, None) async def send_to_user(self, user_id: str, payload: dict) -> int: async with self._lock: @@ -39,6 +68,38 @@ async def send_to_user(self, user_id: str, payload: dict) -> int: await self.disconnect(user_id, websocket) return sent + async def send_to_device(self, device_uid: str, payload: dict) -> int: + async with self._lock: + sockets = list(self._device_connections.get(device_uid, set())) + + if not sockets: + return 0 + + sent = 0 + for websocket in sockets: + try: + await websocket.send_json(payload) + sent += 1 + except Exception: + meta = self._reverse_index.get(websocket) + await self.disconnect(meta.user_id if meta else "", websocket, device_uid=device_uid) + return sent + + async def disconnect_device(self, device_uid: str) -> int: + async with self._lock: + sockets = list(self._device_connections.get(device_uid, set())) + + closed = 0 + for websocket in sockets: + meta = self._reverse_index.get(websocket) + try: + await websocket.close(code=1008, reason="Device revoked") + closed += 1 + except Exception: + pass + await self.disconnect(meta.user_id if meta else "", websocket, device_uid=device_uid) + return closed + async def has_user(self, user_id: str) -> bool: async with self._lock: return bool(self._connections.get(user_id)) diff --git a/server/migrations/versions/20260223_0004_v2_security_core.py b/server/migrations/versions/20260223_0004_v2_security_core.py new file mode 100644 index 0000000..2b8ca28 --- /dev/null +++ b/server/migrations/versions/20260223_0004_v2_security_core.py @@ -0,0 +1,155 @@ +"""v2 security core schema + +Revision ID: 20260223_0004 +Revises: 20260217_0003 +Create Date: 2026-02-23 15:20:00 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260223_0004" +down_revision: str | None = "20260217_0003" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("devices") as batch: + batch.add_column(sa.Column("status", sa.String(length=16), nullable=False, server_default="active")) + batch.add_column( + sa.Column( + "last_seen_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP"), + ) + ) + op.create_index("ix_devices_status", "devices", ["status"], unique=False) + + with op.batch_alter_table("refresh_tokens") as batch: + batch.add_column(sa.Column("device_id", sa.String(length=36), nullable=True)) + batch.add_column(sa.Column("token_kind", sa.String(length=16), nullable=False, server_default="refresh")) + batch.create_foreign_key( + "refresh_tokens_device_id_fkey", + "devices", + ["device_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_index("ix_refresh_tokens_device_id", "refresh_tokens", ["device_id"], unique=False) + op.create_index("ix_refresh_tokens_token_kind", "refresh_tokens", ["token_kind"], unique=False) + op.create_index( + "ix_refresh_tokens_user_device_revoked", + "refresh_tokens", + ["user_id", "device_id", "revoked_at"], + unique=False, + ) + + op.create_table( + "message_events", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "conversation_id", + sa.String(length=36), + sa.ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("sender_user_id", sa.String(length=36), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("sender_address", sa.String(length=320), nullable=False), + sa.Column("sender_device_uid", sa.String(length=36), nullable=False), + sa.Column("sender_device_pubkey", sa.String(length=256), nullable=False), + sa.Column("client_message_id", sa.String(length=64), nullable=False), + sa.Column("sent_at_ms", sa.Integer(), nullable=False), + sa.Column("sender_prev_hash", sa.String(length=128), nullable=False, server_default=""), + sa.Column("sender_chain_hash", sa.String(length=128), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint( + "sender_device_uid", + "client_message_id", + name="uq_message_event_sender_client_message", + ), + ) + op.create_index("ix_message_events_conversation_id", "message_events", ["conversation_id"], unique=False) + op.create_index("ix_message_events_sender_user_id", "message_events", ["sender_user_id"], unique=False) + op.create_index("ix_message_events_sender_address", "message_events", ["sender_address"], unique=False) + op.create_index("ix_message_events_sender_device_uid", "message_events", ["sender_device_uid"], unique=False) + op.create_index("ix_message_events_client_message_id", "message_events", ["client_message_id"], unique=False) + op.create_index("ix_message_events_sender_chain_hash", "message_events", ["sender_chain_hash"], unique=False) + op.create_index("ix_message_events_sent_at_ms", "message_events", ["sent_at_ms"], unique=False) + + op.create_table( + "message_device_copies", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "message_event_id", + sa.String(length=36), + sa.ForeignKey("message_events.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "recipient_user_id", + sa.String(length=36), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("recipient_device_uid", sa.String(length=36), nullable=False), + sa.Column("envelope_json", sa.JSON(), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False, server_default="pending"), + sa.Column("available_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("attempt_count", sa.Integer(), nullable=False, server_default="0"), + sa.Column("last_attempt_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint( + "message_event_id", + "recipient_device_uid", + name="uq_message_copy_event_recipient_device", + ), + ) + op.create_index("ix_message_device_copies_message_event_id", "message_device_copies", ["message_event_id"], unique=False) + op.create_index("ix_message_device_copies_recipient_user_id", "message_device_copies", ["recipient_user_id"], unique=False) + op.create_index( + "ix_message_device_copies_recipient_device_uid", + "message_device_copies", + ["recipient_device_uid"], + unique=False, + ) + op.create_index("ix_message_device_copies_status", "message_device_copies", ["status"], unique=False) + op.create_index("ix_message_device_copies_expires_at", "message_device_copies", ["expires_at"], unique=False) + + +def downgrade() -> None: + op.drop_index("ix_message_device_copies_expires_at", table_name="message_device_copies") + op.drop_index("ix_message_device_copies_status", table_name="message_device_copies") + op.drop_index("ix_message_device_copies_recipient_device_uid", table_name="message_device_copies") + op.drop_index("ix_message_device_copies_recipient_user_id", table_name="message_device_copies") + op.drop_index("ix_message_device_copies_message_event_id", table_name="message_device_copies") + op.drop_table("message_device_copies") + + op.drop_index("ix_message_events_sent_at_ms", table_name="message_events") + op.drop_index("ix_message_events_sender_chain_hash", table_name="message_events") + op.drop_index("ix_message_events_client_message_id", table_name="message_events") + op.drop_index("ix_message_events_sender_device_uid", table_name="message_events") + op.drop_index("ix_message_events_sender_address", table_name="message_events") + op.drop_index("ix_message_events_sender_user_id", table_name="message_events") + op.drop_index("ix_message_events_conversation_id", table_name="message_events") + op.drop_table("message_events") + + op.drop_index("ix_refresh_tokens_user_device_revoked", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_token_kind", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_device_id", table_name="refresh_tokens") + with op.batch_alter_table("refresh_tokens") as batch: + batch.drop_constraint("refresh_tokens_device_id_fkey", type_="foreignkey") + batch.drop_column("token_kind") + batch.drop_column("device_id") + + op.drop_index("ix_devices_status", table_name="devices") + with op.batch_alter_table("devices") as batch: + batch.drop_column("last_seen_at") + batch.drop_column("status") + diff --git a/server/migrations/versions/20260223_0005_message_event_sent_at_ms_bigint.py b/server/migrations/versions/20260223_0005_message_event_sent_at_ms_bigint.py new file mode 100644 index 0000000..885ce4d --- /dev/null +++ b/server/migrations/versions/20260223_0005_message_event_sent_at_ms_bigint.py @@ -0,0 +1,37 @@ +"""message_events.sent_at_ms bigint + +Revision ID: 20260223_0005 +Revises: 20260223_0004 +Create Date: 2026-02-23 16:20:00 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260223_0005" +down_revision: str | None = "20260223_0004" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.alter_column( + "message_events", + "sent_at_ms", + existing_type=sa.Integer(), + type_=sa.BigInteger(), + existing_nullable=False, + ) + + +def downgrade() -> None: + op.alter_column( + "message_events", + "sent_at_ms", + existing_type=sa.BigInteger(), + type_=sa.Integer(), + existing_nullable=False, + ) diff --git a/server/migrations/versions/20260223_0006_v2_ratchet_core.py b/server/migrations/versions/20260223_0006_v2_ratchet_core.py new file mode 100644 index 0000000..0e80113 --- /dev/null +++ b/server/migrations/versions/20260223_0006_v2_ratchet_core.py @@ -0,0 +1,139 @@ +"""v2 ratchet core schema + +Revision ID: 20260223_0006 +Revises: 20260223_0005 +Create Date: 2026-02-23 18:00:00 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260223_0006" +down_revision: str | None = "20260223_0005" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "device_signed_prekeys", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("device_id", sa.String(length=36), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False), + sa.Column("key_id", sa.Integer(), nullable=False), + sa.Column("pub_x25519_b64", sa.String(length=128), nullable=False), + sa.Column("sig_by_device_sign_key_b64", sa.String(length=256), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True), + sa.UniqueConstraint("device_id", "key_id", name="uq_device_signed_prekey_device_key"), + ) + op.create_index("ix_device_signed_prekeys_device_id", "device_signed_prekeys", ["device_id"], unique=False) + op.create_index( + "ix_device_signed_prekeys_device_exp_revoked", + "device_signed_prekeys", + ["device_id", "expires_at", "revoked_at"], + unique=False, + ) + + op.create_table( + "device_one_time_prekeys", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("device_id", sa.String(length=36), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False), + sa.Column("key_id", sa.Integer(), nullable=False), + sa.Column("pub_x25519_b64", sa.String(length=128), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("consumed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("consumed_by_address", sa.String(length=320), nullable=True), + sa.UniqueConstraint("device_id", "key_id", name="uq_device_one_time_prekey_device_key"), + ) + op.create_index("ix_device_one_time_prekeys_device_id", "device_one_time_prekeys", ["device_id"], unique=False) + op.create_index( + "ix_device_one_time_prekeys_device_consumed", + "device_one_time_prekeys", + ["device_id", "consumed_at"], + unique=False, + ) + + op.create_table( + "ratchet_sessions", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_user_id", sa.String(length=36), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + sa.Column("owner_device_uid", sa.String(length=36), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False), + sa.Column("peer_address", sa.String(length=320), nullable=False), + sa.Column("peer_device_uid", sa.String(length=64), nullable=False), + sa.Column("session_version", sa.String(length=16), nullable=False, server_default="dr_v1"), + sa.Column("state_blob_encrypted_b64", sa.String(length=65536), nullable=False, server_default=""), + sa.Column("state_nonce_b64", sa.String(length=128), nullable=False, server_default=""), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("last_send_chain_n", sa.Integer(), nullable=False, server_default="0"), + sa.Column("last_recv_chain_n", sa.Integer(), nullable=False, server_default="0"), + sa.Column("last_root_key_hash", sa.String(length=128), nullable=False, server_default=""), + sa.UniqueConstraint( + "owner_device_uid", + "peer_address", + "peer_device_uid", + name="uq_ratchet_session_owner_peer", + ), + ) + op.create_index("ix_ratchet_sessions_owner_user_id", "ratchet_sessions", ["owner_user_id"], unique=False) + op.create_index("ix_ratchet_sessions_owner_device_uid", "ratchet_sessions", ["owner_device_uid"], unique=False) + op.create_index("ix_ratchet_sessions_peer_address", "ratchet_sessions", ["peer_address"], unique=False) + op.create_index("ix_ratchet_sessions_peer_device_uid", "ratchet_sessions", ["peer_device_uid"], unique=False) + + op.create_table( + "ratchet_skipped_keys", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_device_uid", sa.String(length=36), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False), + sa.Column("peer_address", sa.String(length=320), nullable=False), + sa.Column("peer_device_uid", sa.String(length=64), nullable=False), + sa.Column("dh_pub_b64", sa.String(length=128), nullable=False), + sa.Column("msg_n", sa.Integer(), nullable=False), + sa.Column("mk_encrypted_b64", sa.String(length=512), nullable=False), + sa.Column("mk_nonce_b64", sa.String(length=128), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("used_at", sa.DateTime(timezone=True), nullable=True), + sa.UniqueConstraint( + "owner_device_uid", + "peer_address", + "peer_device_uid", + "dh_pub_b64", + "msg_n", + name="uq_ratchet_skipped_owner_peer_dh_n", + ), + ) + op.create_index("ix_ratchet_skipped_keys_owner_device_uid", "ratchet_skipped_keys", ["owner_device_uid"], unique=False) + op.create_index("ix_ratchet_skipped_keys_peer_address", "ratchet_skipped_keys", ["peer_address"], unique=False) + op.create_index("ix_ratchet_skipped_keys_peer_device_uid", "ratchet_skipped_keys", ["peer_device_uid"], unique=False) + + with op.batch_alter_table("message_events") as batch: + batch.add_column(sa.Column("encryption_mode", sa.String(length=32), nullable=False, server_default="sealedbox_v0_2a")) + op.create_index("ix_message_events_encryption_mode", "message_events", ["encryption_mode"], unique=False) + + +def downgrade() -> None: + op.drop_index("ix_message_events_encryption_mode", table_name="message_events") + with op.batch_alter_table("message_events") as batch: + batch.drop_column("encryption_mode") + + op.drop_index("ix_ratchet_skipped_keys_peer_device_uid", table_name="ratchet_skipped_keys") + op.drop_index("ix_ratchet_skipped_keys_peer_address", table_name="ratchet_skipped_keys") + op.drop_index("ix_ratchet_skipped_keys_owner_device_uid", table_name="ratchet_skipped_keys") + op.drop_table("ratchet_skipped_keys") + + op.drop_index("ix_ratchet_sessions_peer_device_uid", table_name="ratchet_sessions") + op.drop_index("ix_ratchet_sessions_peer_address", table_name="ratchet_sessions") + op.drop_index("ix_ratchet_sessions_owner_device_uid", table_name="ratchet_sessions") + op.drop_index("ix_ratchet_sessions_owner_user_id", table_name="ratchet_sessions") + op.drop_table("ratchet_sessions") + + op.drop_index("ix_device_one_time_prekeys_device_consumed", table_name="device_one_time_prekeys") + op.drop_index("ix_device_one_time_prekeys_device_id", table_name="device_one_time_prekeys") + op.drop_table("device_one_time_prekeys") + + op.drop_index("ix_device_signed_prekeys_device_exp_revoked", table_name="device_signed_prekeys") + op.drop_index("ix_device_signed_prekeys_device_id", table_name="device_signed_prekeys") + op.drop_table("device_signed_prekeys") diff --git a/server/migrations/versions/20260224_0007_group_conversation_core.py b/server/migrations/versions/20260224_0007_group_conversation_core.py new file mode 100644 index 0000000..928706c --- /dev/null +++ b/server/migrations/versions/20260224_0007_group_conversation_core.py @@ -0,0 +1,119 @@ +"""group conversation core + +Revision ID: 20260224_0007 +Revises: 20260223_0006 +Create Date: 2026-02-24 10:30:00 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260224_0007" +down_revision: str | None = "20260223_0006" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column( + "conversations", + sa.Column("conversation_type", sa.String(length=16), nullable=False, server_default="direct"), + ) + op.add_column( + "conversations", + sa.Column("group_uid", sa.String(length=64), nullable=True), + ) + op.add_column( + "conversations", + sa.Column("group_name", sa.String(length=128), nullable=False, server_default=""), + ) + op.add_column( + "conversations", + sa.Column("origin_server_onion", sa.String(length=255), nullable=False, server_default=""), + ) + op.add_column( + "conversations", + sa.Column("owner_address", sa.String(length=320), nullable=False, server_default=""), + ) + op.create_index("ix_conversations_conversation_type", "conversations", ["conversation_type"], unique=False) + op.create_index("ix_conversations_group_uid", "conversations", ["group_uid"], unique=True) + + op.create_table( + "conversation_members", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "conversation_id", + sa.String(length=36), + sa.ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "member_user_id", + sa.String(length=36), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("member_address", sa.String(length=320), nullable=False), + sa.Column("member_server_onion", sa.String(length=255), nullable=False), + sa.Column("role", sa.String(length=16), nullable=False), + sa.Column("status", sa.String(length=16), nullable=False), + sa.Column("invited_by_address", sa.String(length=320), nullable=False), + sa.Column("invited_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("joined_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("left_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("conversation_id", "member_address", name="uq_conversation_member_address"), + ) + op.create_index("ix_conversation_members_conversation_id", "conversation_members", ["conversation_id"], unique=False) + op.create_index("ix_conversation_members_member_user_id", "conversation_members", ["member_user_id"], unique=False) + op.create_index("ix_conversation_members_member_address", "conversation_members", ["member_address"], unique=False) + op.create_index( + "ix_conversation_members_member_server_onion", + "conversation_members", + ["member_server_onion"], + unique=False, + ) + op.create_index("ix_conversation_members_role", "conversation_members", ["role"], unique=False) + op.create_index("ix_conversation_members_status", "conversation_members", ["status"], unique=False) + + op.create_table( + "group_membership_events", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("group_uid", sa.String(length=64), nullable=False), + sa.Column("event_seq", sa.Integer(), nullable=False), + sa.Column("event_type", sa.String(length=24), nullable=False), + sa.Column("actor_address", sa.String(length=320), nullable=False), + sa.Column("target_address", sa.String(length=320), nullable=False), + sa.Column("payload_json", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("group_uid", "event_seq", name="uq_group_membership_event_seq"), + ) + op.create_index("ix_group_membership_events_group_uid", "group_membership_events", ["group_uid"], unique=False) + op.create_index("ix_group_membership_events_event_seq", "group_membership_events", ["event_seq"], unique=False) + op.create_index("ix_group_membership_events_event_type", "group_membership_events", ["event_type"], unique=False) + + +def downgrade() -> None: + op.drop_index("ix_group_membership_events_event_type", table_name="group_membership_events") + op.drop_index("ix_group_membership_events_event_seq", table_name="group_membership_events") + op.drop_index("ix_group_membership_events_group_uid", table_name="group_membership_events") + op.drop_table("group_membership_events") + + op.drop_index("ix_conversation_members_status", table_name="conversation_members") + op.drop_index("ix_conversation_members_role", table_name="conversation_members") + op.drop_index("ix_conversation_members_member_server_onion", table_name="conversation_members") + op.drop_index("ix_conversation_members_member_address", table_name="conversation_members") + op.drop_index("ix_conversation_members_member_user_id", table_name="conversation_members") + op.drop_index("ix_conversation_members_conversation_id", table_name="conversation_members") + op.drop_table("conversation_members") + + op.drop_index("ix_conversations_group_uid", table_name="conversations") + op.drop_index("ix_conversations_conversation_type", table_name="conversations") + op.drop_column("conversations", "owner_address") + op.drop_column("conversations", "origin_server_onion") + op.drop_column("conversations", "group_name") + op.drop_column("conversations", "group_uid") + op.drop_column("conversations", "conversation_type") diff --git a/server/migrations/versions/20260224_0008_group_call_core.py b/server/migrations/versions/20260224_0008_group_call_core.py new file mode 100644 index 0000000..cd7f559 --- /dev/null +++ b/server/migrations/versions/20260224_0008_group_call_core.py @@ -0,0 +1,105 @@ +"""group call core + +Revision ID: 20260224_0008 +Revises: 20260224_0007 +Create Date: 2026-02-24 10:55:00 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260224_0008" +down_revision: str | None = "20260224_0007" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "group_call_sessions", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("call_id", sa.String(length=36), nullable=False, unique=True), + sa.Column( + "conversation_id", + sa.String(length=36), + sa.ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("group_uid", sa.String(length=64), nullable=False), + sa.Column("initiator_address", sa.String(length=320), nullable=False), + sa.Column("state", sa.String(length=16), nullable=False), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("ring_expires_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_group_call_sessions_call_id", "group_call_sessions", ["call_id"], unique=True) + op.create_index( + "ix_group_call_sessions_conversation_id", + "group_call_sessions", + ["conversation_id"], + unique=False, + ) + op.create_index("ix_group_call_sessions_group_uid", "group_call_sessions", ["group_uid"], unique=False) + op.create_index("ix_group_call_sessions_state", "group_call_sessions", ["state"], unique=False) + op.create_index( + "ix_group_call_sessions_ring_expires_at", + "group_call_sessions", + ["ring_expires_at"], + unique=False, + ) + + op.create_table( + "group_call_participants", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "call_id", + sa.String(length=36), + sa.ForeignKey("group_call_sessions.call_id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("member_address", sa.String(length=320), nullable=False), + sa.Column( + "local_user_id", + sa.String(length=36), + sa.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("state", sa.String(length=24), nullable=False), + sa.Column("invited_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("joined_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("left_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_signal_at", sa.DateTime(timezone=True), nullable=True), + sa.UniqueConstraint("call_id", "member_address", name="uq_group_call_participant_address"), + ) + op.create_index("ix_group_call_participants_call_id", "group_call_participants", ["call_id"], unique=False) + op.create_index( + "ix_group_call_participants_member_address", + "group_call_participants", + ["member_address"], + unique=False, + ) + op.create_index( + "ix_group_call_participants_local_user_id", + "group_call_participants", + ["local_user_id"], + unique=False, + ) + op.create_index("ix_group_call_participants_state", "group_call_participants", ["state"], unique=False) + + +def downgrade() -> None: + op.drop_index("ix_group_call_participants_state", table_name="group_call_participants") + op.drop_index("ix_group_call_participants_local_user_id", table_name="group_call_participants") + op.drop_index("ix_group_call_participants_member_address", table_name="group_call_participants") + op.drop_index("ix_group_call_participants_call_id", table_name="group_call_participants") + op.drop_table("group_call_participants") + + op.drop_index("ix_group_call_sessions_ring_expires_at", table_name="group_call_sessions") + op.drop_index("ix_group_call_sessions_state", table_name="group_call_sessions") + op.drop_index("ix_group_call_sessions_group_uid", table_name="group_call_sessions") + op.drop_index("ix_group_call_sessions_conversation_id", table_name="group_call_sessions") + op.drop_index("ix_group_call_sessions_call_id", table_name="group_call_sessions") + op.drop_table("group_call_sessions") diff --git a/server/pyproject.toml b/server/pyproject.toml index 70ec680..a2890ea 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ "httpx>=0.28.1,<1.0.0", "requests>=2.32.5,<3.0.0", "websockets>=15.0.1,<16.0.0", - "PyNaCl>=1.5.0,<2.0.0" + "PyNaCl>=1.5.0,<2.0.0", + "cryptography>=44.0.0,<46.0.0" ] [project.optional-dependencies] diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 445d92e..4cbd791 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -8,6 +8,7 @@ from app.config import reset_settings_cache from app.db import reset_engine +from app.security.tokens_v2 import reset_v2_token_cache @pytest.fixture() @@ -24,6 +25,7 @@ def client(tmp_path: Path) -> Generator[TestClient, None, None]: os.environ["BLACKWIRE_FEDERATION_SIGNING_PRIVATE_KEY_B64"] = "" reset_settings_cache() + reset_v2_token_cache() asyncio.run(reset_engine()) from app.main import create_app @@ -35,3 +37,4 @@ def client(tmp_path: Path) -> Generator[TestClient, None, None]: asyncio.run(reset_engine()) reset_settings_cache() + reset_v2_token_cache() diff --git a/server/tests/integration/test_v2_attachment_policy.py b/server/tests/integration/test_v2_attachment_policy.py new file mode 100644 index 0000000..f1414d0 --- /dev/null +++ b/server/tests/integration/test_v2_attachment_policy.py @@ -0,0 +1,141 @@ +from datetime import UTC, datetime +from uuid import uuid4 + +from app.config import get_settings +from app.schemas.v2_device import UserDeviceLookupV2 + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _register_v2_user(client, username: str, password: str = "password123") -> dict: + response = client.post("/api/v2/auth/register", json={"username": username, "password": password}) + assert response.status_code == 201, response.text + return response.json() + + +def _register_v2_device(client, bootstrap_token: str, label: str) -> dict: + sign_key = (f"sign-{label}-{uuid4().hex}" * 3)[:64] + dh_key = (f"dh-{label}-{uuid4().hex}" * 3)[:64] + response = client.post( + "/api/v2/devices/register", + headers=_auth_header(bootstrap_token), + json={ + "label": label, + "pub_sign_key": sign_key, + "pub_dh_key": dh_key, + }, + ) + assert response.status_code == 200, response.text + return response.json() + + +def test_v2_resolve_devices_includes_attachment_policy(client) -> None: + alice = _register_v2_user(client, "alice_v2_attach_local") + bob = _register_v2_user(client, "bob_v2_attach_local") + alice_device = _register_v2_device(client, alice["tokens"]["bootstrap_token"], "alice-device") + _register_v2_device(client, bob["tokens"]["bootstrap_token"], "bob-device") + + lookup = client.get( + "/api/v2/users/resolve-devices", + headers=_auth_header(alice_device["tokens"]["access_token"]), + params={"peer_address": "bob_v2_attach_local@local.invalid"}, + ) + assert lookup.status_code == 200, lookup.text + body = lookup.json() + assert body["attachment_inline_max_bytes"] == get_settings().effective_attachment_inline_max_bytes() + assert body["max_ciphertext_bytes"] == get_settings().effective_max_ciphertext_bytes() + assert body["attachment_policy_source"] == "local" + + +def test_v2_federation_user_devices_exposes_attachment_policy(client) -> None: + bob = _register_v2_user(client, "bob_v2_attach_federation") + _register_v2_device(client, bob["tokens"]["bootstrap_token"], "bob-device") + + response = client.get("/api/v2/federation/users/bob_v2_attach_federation/devices") + assert response.status_code == 200, response.text + body = response.json() + assert body["attachment_inline_max_bytes"] == get_settings().effective_attachment_inline_max_bytes() + assert body["max_ciphertext_bytes"] == get_settings().effective_max_ciphertext_bytes() + assert body["attachment_policy_source"] == "local" + + +def test_v2_federation_well_known_exposes_attachment_policy(client) -> None: + response = client.get("/api/v2/federation/well-known") + assert response.status_code == 200, response.text + body = response.json() + assert body["attachment_inline_max_bytes"] == get_settings().effective_attachment_inline_max_bytes() + assert body["max_ciphertext_bytes"] == get_settings().effective_max_ciphertext_bytes() + assert body["attachment_hard_ceiling_bytes"] == get_settings().attachment_hard_ceiling_bytes + + +def test_v2_resolve_devices_remote_policy_missing_falls_back_local(client, monkeypatch) -> None: + alice = _register_v2_user(client, "alice_v2_attach_remote") + alice_device = _register_v2_device(client, alice["tokens"]["bootstrap_token"], "alice-device") + + now = datetime.now(UTC).isoformat() + + async def _fake_remote_lookup(_peer_onion: str, username: str) -> UserDeviceLookupV2: + return UserDeviceLookupV2.model_validate( + { + "username": username, + "peer_address": f"{username}@remote.invalid", + "devices": [ + { + "device_uid": "remote-device-1", + "user_id": "remote-user-1", + "label": "remote", + "pub_sign_key": "C" * 44, + "pub_dh_key": "D" * 44, + "status": "active", + "supported_message_modes": ["sealedbox_v0_2a"], + "created_at": now, + "last_seen_at": now, + "revoked_at": None, + } + ], + "attachment_inline_max_bytes": 0, + "max_ciphertext_bytes": 0, + "attachment_policy_source": "remote", + } + ) + + monkeypatch.setattr( + "app.services.federation_client.federation_client.get_remote_user_devices_v2", + _fake_remote_lookup, + ) + + response = client.get( + "/api/v2/users/resolve-devices", + headers=_auth_header(alice_device["tokens"]["access_token"]), + params={"peer_address": "bob_remote@remote.invalid"}, + ) + assert response.status_code == 200, response.text + body = response.json() + assert body["attachment_policy_source"] == "fallback_local" + assert body["attachment_inline_max_bytes"] == get_settings().effective_attachment_inline_max_bytes() + assert body["max_ciphertext_bytes"] == get_settings().effective_max_ciphertext_bytes() + + +def test_v2_messages_send_rejects_payload_above_body_cap(client) -> None: + alice = _register_v2_user(client, "alice_v2_attach_body") + alice_device = _register_v2_device(client, alice["tokens"]["bootstrap_token"], "alice-device") + + settings = get_settings() + original_cap = settings.max_client_message_body_bytes + settings.max_client_message_body_bytes = 128 + try: + payload = '{"conversation_id":"' + ("x" * 256) + '"}' + response = client.post( + "/api/v2/messages/send", + headers={ + **_auth_header(alice_device["tokens"]["access_token"]), + "Content-Type": "application/json", + }, + content=payload.encode("utf-8"), + ) + finally: + settings.max_client_message_body_bytes = original_cap + + assert response.status_code == 413, response.text diff --git a/server/tests/integration/test_v2_group_dm_calls.py b/server/tests/integration/test_v2_group_dm_calls.py new file mode 100644 index 0000000..c1b0799 --- /dev/null +++ b/server/tests/integration/test_v2_group_dm_calls.py @@ -0,0 +1,726 @@ +import asyncio +import base64 +import hashlib +import os +import time +import uuid +from collections.abc import Generator +from pathlib import Path +from typing import Any + +import pytest +from fastapi.testclient import TestClient +from nacl import encoding, public, signing + +from app.config import get_settings, reset_settings_cache +from app.db import reset_engine +from app.security.tokens_v2 import reset_v2_token_cache + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _b64(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +def _sha256_hex(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _aggregate_chain_hash(sender_prev_hash: str, client_message_id: str, sent_at_ms: int, hash_material: list[str]) -> str: + aggregate = _sha256_hex("|".join(sorted(hash_material)).encode("utf-8")) + chain_data = "\n".join([sender_prev_hash, client_message_id, str(sent_at_ms), aggregate]).encode("utf-8") + return _sha256_hex(chain_data) + + +def _canonical_message_signature_string( + sender_address: str, + sender_device_uid: str, + recipient_user_address: str, + recipient_device_uid: str, + client_message_id: str, + sent_at_ms: int, + sender_prev_hash: str, + sender_chain_hash: str, + ciphertext_hash: str, + aad_hash: str, +) -> bytes: + canonical = "\n".join( + [ + sender_address, + sender_device_uid, + recipient_user_address, + recipient_device_uid, + client_message_id, + str(sent_at_ms), + sender_prev_hash, + sender_chain_hash, + ciphertext_hash, + aad_hash, + ] + ) + return canonical.encode("utf-8") + + +def _new_device_material(label: str) -> dict[str, Any]: + sign_sk = signing.SigningKey.generate() + sign_pk_b64 = sign_sk.verify_key.encode(encoder=encoding.Base64Encoder).decode("utf-8") + dh_sk = public.PrivateKey.generate() + dh_pk_b64 = _b64(bytes(dh_sk.public_key)) + return { + "label": label, + "sign_sk": sign_sk, + "sign_pk_b64": sign_pk_b64, + "dh_pk_b64": dh_pk_b64, + } + + +def _register_v2_user(client: TestClient, username: str, password: str = "password123") -> dict[str, Any]: + response = client.post("/api/v2/auth/register", json={"username": username, "password": password}) + assert response.status_code == 201, response.text + return response.json() + + +def _register_device_v2(client: TestClient, bootstrap_token: str, material: dict[str, Any]) -> dict[str, Any]: + response = client.post( + "/api/v2/devices/register", + headers=_auth_header(bootstrap_token), + json={ + "label": material["label"], + "pub_sign_key": material["sign_pk_b64"], + "pub_dh_key": material["dh_pk_b64"], + }, + ) + assert response.status_code == 200, response.text + return response.json() + + +def _build_group_send_payload( + *, + conversation_id: str, + sender_address: str, + sender_device_uid: str, + sender_sign_material: dict[str, Any], + sender_prev_hash: str, + targets: list[dict[str, Any]], + plaintext: bytes, +) -> dict[str, Any]: + client_message_id = str(uuid.uuid4()) + sent_at_ms = int(time.time() * 1000) + envelopes: list[dict[str, Any]] = [] + hash_material: list[str] = [] + for target in targets: + recipient_address = str(target["member_address"]).strip().lower() + recipient_device_uid = str(target["device"]["device_uid"]) + ciphertext_b64 = _b64(plaintext) + ciphertext_hash = _sha256_hex(plaintext) + aad_hash = _sha256_hex(b"") + hash_material.append(f"{recipient_device_uid}:{ciphertext_hash}:{aad_hash}") + envelopes.append( + { + "recipient_user_address": recipient_address, + "recipient_device_uid": recipient_device_uid, + "ciphertext_b64": ciphertext_b64, + "aad_b64": None, + "signature_b64": "", + "sender_device_pubkey": sender_sign_material["sign_pk_b64"], + } + ) + sender_chain_hash = _aggregate_chain_hash(sender_prev_hash, client_message_id, sent_at_ms, hash_material) + for envelope in envelopes: + canonical = _canonical_message_signature_string( + sender_address=sender_address, + sender_device_uid=sender_device_uid, + recipient_user_address=envelope["recipient_user_address"], + recipient_device_uid=envelope["recipient_device_uid"], + client_message_id=client_message_id, + sent_at_ms=sent_at_ms, + sender_prev_hash=sender_prev_hash, + sender_chain_hash=sender_chain_hash, + ciphertext_hash=_sha256_hex(base64.b64decode(envelope["ciphertext_b64"])), + aad_hash=_sha256_hex(b""), + ) + envelope["signature_b64"] = _b64(sender_sign_material["sign_sk"].sign(canonical).signature) + return { + "conversation_id": conversation_id, + "client_message_id": client_message_id, + "sent_at_ms": sent_at_ms, + "sender_prev_hash": sender_prev_hash, + "sender_chain_hash": sender_chain_hash, + "envelopes": envelopes, + } + + +def _receive_until(ws, predicate, max_events: int = 24) -> dict[str, Any]: + seen: list[str] = [] + for _ in range(max_events): + event = ws.receive_json() + seen.append(str(event.get("type", ""))) + if predicate(event): + return event + raise AssertionError(f"Expected websocket event not observed; seen={seen}") + + +@pytest.fixture() +def group_client(tmp_path: Path) -> Generator[TestClient, None, None]: + db_path = tmp_path / "group_v2.db" + os.environ["BLACKWIRE_ENVIRONMENT"] = "test" + os.environ["BLACKWIRE_DATABASE_URL"] = f"sqlite+aiosqlite:///{db_path.as_posix()}" + os.environ["BLACKWIRE_AUTO_CREATE_TABLES"] = "true" + os.environ["BLACKWIRE_JWT_SECRET_KEY"] = "test-secret-key-with-at-least-32-bytes" + os.environ["BLACKWIRE_RATE_LIMIT_PER_MINUTE"] = "10000" + os.environ["BLACKWIRE_VOICE_CALL_RING_TIMEOUT_SECONDS"] = "2" + os.environ["BLACKWIRE_TOR_ENABLED"] = "false" + os.environ["BLACKWIRE_FEDERATION_SERVER_ONION"] = "local.invalid" + os.environ["BLACKWIRE_FEDERATION_SIGNING_PRIVATE_KEY_B64"] = "" + os.environ["BLACKWIRE_ENABLE_GROUP_DM_V2C"] = "true" + os.environ["BLACKWIRE_ENABLE_GROUP_CALL_V2C"] = "true" + + reset_settings_cache() + reset_v2_token_cache() + asyncio.run(reset_engine()) + + from app.main import create_app + + app = create_app() + settings = get_settings() + from app.services.group_call_service import group_call_service + from app.services.group_conversation_service import group_conversation_service + from app.services.message_service_v2 import message_service_v2 + + group_conversation_service.settings = settings + group_call_service.settings = settings + message_service_v2.settings = settings + with TestClient(app) as test_client: + yield test_client + + asyncio.run(reset_engine()) + reset_settings_cache() + reset_v2_token_cache() + + +def test_v2_group_dm_no_prejoin_history(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_owner") + bob_register = _register_v2_user(group_client, "bob_group_member") + carol_register = _register_v2_user(group_client, "carol_group_member") + + alice_material = _new_device_material("alice-group-device") + bob_material = _new_device_material("bob-group-device") + carol_material = _new_device_material("carol-group-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + carol_tokens = _register_device_v2(group_client, carol_register["tokens"]["bootstrap_token"], carol_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "Ops Room", "member_addresses": ["bob_group_member@local.invalid"]}, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + assert create_group.json()["conversation_type"] == "group" + + bob_invited = group_client.get( + f"/api/v2/conversations/{conversation_id}/messages", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_invited.status_code == 403 + + bob_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_accept.status_code == 200, bob_accept.text + assert bob_accept.json()["status"] == "active" + + recipients_1 = group_client.get( + f"/api/v2/conversations/{conversation_id}/recipients", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert recipients_1.status_code == 200, recipients_1.text + targets_1 = recipients_1.json()["recipients"] + assert len(targets_1) == 1 + assert targets_1[0]["member_address"] == "bob_group_member@local.invalid" + + first_payload = _build_group_send_payload( + conversation_id=conversation_id, + sender_address="alice_group_owner@local.invalid", + sender_device_uid=alice_tokens["device_uid"], + sender_sign_material=alice_material, + sender_prev_hash="", + targets=targets_1, + plaintext=b"hello-group-before-carol", + ) + send_1 = group_client.post( + "/api/v2/messages/send", + headers=_auth_header(alice_tokens["access_token"]), + json=first_payload, + ) + assert send_1.status_code == 200, send_1.text + first_chain = send_1.json()["message"]["sender_chain_hash"] + + invite_carol = group_client.post( + f"/api/v2/conversations/{conversation_id}/members/invite", + headers=_auth_header(alice_tokens["access_token"]), + json={"member_addresses": ["carol_group_member@local.invalid"]}, + ) + assert invite_carol.status_code == 200, invite_carol.text + + carol_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(carol_tokens["access_token"]), + ) + assert carol_accept.status_code == 200, carol_accept.text + + carol_prejoin_history = group_client.get( + f"/api/v2/conversations/{conversation_id}/messages", + headers=_auth_header(carol_tokens["access_token"]), + ) + assert carol_prejoin_history.status_code == 200, carol_prejoin_history.text + assert carol_prejoin_history.json() == [] + + recipients_2 = group_client.get( + f"/api/v2/conversations/{conversation_id}/recipients", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert recipients_2.status_code == 200, recipients_2.text + targets_2 = recipients_2.json()["recipients"] + recipient_addresses = sorted(item["member_address"] for item in targets_2) + assert recipient_addresses == ["bob_group_member@local.invalid", "carol_group_member@local.invalid"] + + second_payload = _build_group_send_payload( + conversation_id=conversation_id, + sender_address="alice_group_owner@local.invalid", + sender_device_uid=alice_tokens["device_uid"], + sender_sign_material=alice_material, + sender_prev_hash=first_chain, + targets=targets_2, + plaintext=b"hello-group-after-carol", + ) + send_2 = group_client.post( + "/api/v2/messages/send", + headers=_auth_header(alice_tokens["access_token"]), + json=second_payload, + ) + assert send_2.status_code == 200, send_2.text + + carol_messages = group_client.get( + f"/api/v2/conversations/{conversation_id}/messages", + headers=_auth_header(carol_tokens["access_token"]), + ) + assert carol_messages.status_code == 200, carol_messages.text + assert len(carol_messages.json()) == 1 + + +def test_v2_group_call_offer_join_leave_end(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_call") + bob_register = _register_v2_user(group_client, "bob_group_call") + alice_material = _new_device_material("alice-call-device") + bob_material = _new_device_material("bob-call-device") + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "Call Room", "member_addresses": ["bob_group_call@local.invalid"]}, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + bob_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_accept.status_code == 200, bob_accept.text + + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens["access_token"])) as alice_ws: + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(bob_tokens["access_token"])) as bob_ws: + alice_ws.send_json({"type": "call.offer", "conversation_id": conversation_id}) + + incoming = _receive_until( + bob_ws, + lambda event: event.get("type") == "call.group.incoming", + ) + call_id = str(incoming["call_id"]) + + bob_ws.send_json({"type": "call.accept", "call_id": call_id}) + + active_state = _receive_until( + alice_ws, + lambda event: event.get("type") == "call.group.state" + and event.get("call_id") == call_id + and event.get("state") == "active", + ) + assert active_state["group_uid"] == create_group.json()["group_uid"] + + bob_ws.send_json({"type": "call.end", "call_id": call_id, "reason": "left"}) + bob_ended = _receive_until( + bob_ws, + lambda event: event.get("type") == "call.group.ended" and event.get("call_id") == call_id, + ) + assert bob_ended["reason"] == "left" + post_leave_state = _receive_until( + alice_ws, + lambda event: event.get("type") == "call.group.state" and event.get("call_id") == call_id, + ) + assert post_leave_state["state"] != "ended" + + # Pool semantics: leaving user can re-enter existing active group call. + bob_ws.send_json({"type": "call.offer", "conversation_id": conversation_id}) + bob_rejoin_state = _receive_until( + bob_ws, + lambda event: event.get("type") == "call.group.state" + and event.get("call_id") == call_id + and event.get("state") == "active", + ) + assert bob_rejoin_state["group_uid"] == create_group.json()["group_uid"] + + alice_ws.send_json({"type": "call.end", "call_id": call_id, "reason": "ended"}) + ended = _receive_until( + alice_ws, + lambda event: event.get("type") == "call.group.ended" and event.get("call_id") == call_id, + ) + assert ended["reason"] == "ended" + + +def test_v2_group_call_audio_local_forwarding(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_audio") + bob_register = _register_v2_user(group_client, "bob_group_audio") + alice_material = _new_device_material("alice-audio-device") + bob_material = _new_device_material("bob-audio-device") + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "Audio Room", "member_addresses": ["bob_group_audio@local.invalid"]}, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + bob_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_accept.status_code == 200, bob_accept.text + + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens["access_token"])) as alice_ws: + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(bob_tokens["access_token"])) as bob_ws: + alice_ws.send_json({"type": "call.offer", "conversation_id": conversation_id}) + incoming = _receive_until( + bob_ws, + lambda event: event.get("type") == "call.group.incoming", + ) + call_id = str(incoming["call_id"]) + + bob_ws.send_json({"type": "call.accept", "call_id": call_id}) + _receive_until( + alice_ws, + lambda event: event.get("type") == "call.group.state" + and event.get("call_id") == call_id + and event.get("state") == "active", + ) + + pcm_b64 = _b64(b"\x00" * 320) + alice_ws.send_json( + { + "type": "call.audio", + "call_id": call_id, + "sequence": 1, + "pcm_b64": pcm_b64, + } + ) + audio_event = _receive_until( + bob_ws, + lambda event: event.get("type") == "call.audio" and event.get("call_id") == call_id, + ) + assert audio_event["pcm_b64"] == pcm_b64 + + +def test_v2_group_rename_pushes_ws_event_to_members(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_rename") + bob_register = _register_v2_user(group_client, "bob_group_rename") + + alice_material = _new_device_material("alice-group-rename-device") + bob_material = _new_device_material("bob-group-rename-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "Before Rename", "member_addresses": ["bob_group_rename@local.invalid"]}, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + bob_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_accept.status_code == 200, bob_accept.text + + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens["access_token"])) as alice_ws: + with group_client.websocket_connect("/api/v2/ws", headers=_auth_header(bob_tokens["access_token"])) as bob_ws: + rename = group_client.post( + f"/api/v2/conversations/{conversation_id}/rename", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "After Rename"}, + ) + assert rename.status_code == 200, rename.text + assert rename.json()["group_name"] == "After Rename" + + alice_event = _receive_until( + alice_ws, + lambda event: event.get("type") == "conversation.group.renamed" + and event.get("conversation_id") == conversation_id, + ) + bob_event = _receive_until( + bob_ws, + lambda event: event.get("type") == "conversation.group.renamed" + and event.get("conversation_id") == conversation_id, + ) + + assert alice_event["group_name"] == "After Rename" + assert bob_event["group_name"] == "After Rename" + assert alice_event["actor_address"] == "alice_group_rename@local.invalid" + assert bob_event["actor_address"] == "alice_group_rename@local.invalid" + + +def test_v2_group_create_skips_owner_alias_and_keeps_invite_capacity(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_alias") + bob_register = _register_v2_user(group_client, "bob_group_alias") + carol_register = _register_v2_user(group_client, "carol_group_alias") + + alice_material = _new_device_material("alice-group-alias-device") + bob_material = _new_device_material("bob-group-alias-device") + carol_material = _new_device_material("carol-group-alias-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material) + _register_device_v2(group_client, carol_register["tokens"]["bootstrap_token"], carol_material) + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={ + "name": "Alias Safety Room", + "member_addresses": [ + "alice_group_alias@local.invalid", + "bob_group_alias@local.invalid", + ], + }, + ) + assert create_group.status_code == 200, create_group.text + payload = create_group.json() + assert payload["conversation_type"] == "group" + assert payload["can_manage_members"] is True + assert payload["member_count"] == 2 + + conversation_id = payload["id"] + members = group_client.get( + f"/api/v2/conversations/{conversation_id}/members", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert members.status_code == 200, members.text + member_rows = members.json() + assert len(member_rows) == 2 + owner_rows = [row for row in member_rows if row["role"] == "owner"] + assert len(owner_rows) == 1 + assert owner_rows[0]["status"] == "active" + + invite = group_client.post( + f"/api/v2/conversations/{conversation_id}/members/invite", + headers=_auth_header(alice_tokens["access_token"]), + json={"member_addresses": ["carol_group_alias@local.invalid"]}, + ) + assert invite.status_code == 200, invite.text + + +def test_v2_group_send_allows_sender_fallback_when_no_other_active_targets(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_group_fallback") + bob_register = _register_v2_user(group_client, "bob_group_fallback") + + alice_material = _new_device_material("alice-group-fallback-device") + bob_material = _new_device_material("bob-group-fallback-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material) + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={"name": "Fallback Room", "member_addresses": ["bob_group_fallback@local.invalid"]}, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + # Bob is invited but not active yet; server should still accept a sender-self fallback copy. + self_target = [ + { + "member_address": "alice_group_fallback@local.invalid", + "device": {"device_uid": alice_tokens["device_uid"]}, + } + ] + payload = _build_group_send_payload( + conversation_id=conversation_id, + sender_address="alice_group_fallback@local.invalid", + sender_device_uid=alice_tokens["device_uid"], + sender_sign_material=alice_material, + sender_prev_hash="", + targets=self_target, + plaintext=b"owner-fallback-message", + ) + send = group_client.post( + "/api/v2/messages/send", + headers=_auth_header(alice_tokens["access_token"]), + json=payload, + ) + assert send.status_code == 200, send.text + + alice_messages = group_client.get( + f"/api/v2/conversations/{conversation_id}/messages", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert alice_messages.status_code == 200, alice_messages.text + assert len(alice_messages.json()) == 1 + + +def test_v2_group_owner_leave_transfers_to_oldest_eligible_member(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_leave_owner") + bob_register = _register_v2_user(group_client, "bob_leave_owner") + carol_register = _register_v2_user(group_client, "carol_leave_owner") + + alice_material = _new_device_material("alice-leave-owner-device") + bob_material = _new_device_material("bob-leave-owner-device") + carol_material = _new_device_material("carol-leave-owner-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + _register_device_v2(group_client, carol_register["tokens"]["bootstrap_token"], carol_material) + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={ + "name": "Owner Transfer Room", + "member_addresses": ["bob_leave_owner@local.invalid", "carol_leave_owner@local.invalid"], + }, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + bob_accept = group_client.post( + f"/api/v2/conversations/{conversation_id}/invites/accept", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_accept.status_code == 200, bob_accept.text + + leave_owner = group_client.post( + f"/api/v2/conversations/{conversation_id}/leave", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert leave_owner.status_code == 200, leave_owner.text + assert leave_owner.json()["status"] == "left" + + bob_conversations = group_client.get( + "/api/v2/conversations", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert bob_conversations.status_code == 200, bob_conversations.text + matching = [row for row in bob_conversations.json() if row["id"] == conversation_id] + assert len(matching) == 1 + assert matching[0]["owner_address"] == "bob_leave_owner@local.invalid" + assert matching[0]["can_manage_members"] is True + + members = group_client.get( + f"/api/v2/conversations/{conversation_id}/members", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert members.status_code == 200, members.text + member_by_address = {row["member_address"]: row for row in members.json()} + assert member_by_address["alice_leave_owner@local.invalid"]["status"] == "left" + assert member_by_address["alice_leave_owner@local.invalid"]["role"] == "member" + assert member_by_address["bob_leave_owner@local.invalid"]["status"] == "active" + assert member_by_address["bob_leave_owner@local.invalid"]["role"] == "owner" + + +def test_v2_group_owner_leave_promotes_oldest_invited_when_no_active_non_owner(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_leave_invited") + bob_register = _register_v2_user(group_client, "bob_leave_invited") + + alice_material = _new_device_material("alice-leave-invited-device") + bob_material = _new_device_material("bob-leave-invited-device") + + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + bob_tokens = _register_device_v2(group_client, bob_register["tokens"]["bootstrap_token"], bob_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={ + "name": "Owner Transfer Invited Room", + "member_addresses": ["bob_leave_invited@local.invalid"], + }, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + leave_owner = group_client.post( + f"/api/v2/conversations/{conversation_id}/leave", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert leave_owner.status_code == 200, leave_owner.text + assert leave_owner.json()["status"] == "left" + + members = group_client.get( + f"/api/v2/conversations/{conversation_id}/members", + headers=_auth_header(bob_tokens["access_token"]), + ) + assert members.status_code == 200, members.text + member_by_address = {row["member_address"]: row for row in members.json()} + assert member_by_address["bob_leave_invited@local.invalid"]["role"] == "owner" + assert member_by_address["bob_leave_invited@local.invalid"]["status"] == "active" + + +def test_v2_group_owner_leave_deletes_group_when_no_eligible_non_owner(group_client: TestClient) -> None: + alice_register = _register_v2_user(group_client, "alice_leave_delete") + alice_material = _new_device_material("alice-leave-delete-device") + alice_tokens = _register_device_v2(group_client, alice_register["tokens"]["bootstrap_token"], alice_material)["tokens"] + + create_group = group_client.post( + "/api/v2/conversations/group", + headers=_auth_header(alice_tokens["access_token"]), + json={ + "name": "Owner Delete Room", + "member_addresses": [], + }, + ) + assert create_group.status_code == 200, create_group.text + conversation_id = create_group.json()["id"] + + leave_owner = group_client.post( + f"/api/v2/conversations/{conversation_id}/leave", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert leave_owner.status_code == 200, leave_owner.text + assert leave_owner.json()["status"] == "left" + + list_after = group_client.get( + "/api/v2/conversations", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert list_after.status_code == 200, list_after.text + assert all(row["id"] != conversation_id for row in list_after.json()) + + members = group_client.get( + f"/api/v2/conversations/{conversation_id}/members", + headers=_auth_header(alice_tokens["access_token"]), + ) + assert members.status_code == 404 diff --git a/server/tests/integration/test_v2_local_alias_routing.py b/server/tests/integration/test_v2_local_alias_routing.py new file mode 100644 index 0000000..f76deb3 --- /dev/null +++ b/server/tests/integration/test_v2_local_alias_routing.py @@ -0,0 +1,83 @@ +from app.config import get_settings + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _register_v2_user(client, username: str, password: str = "password123") -> dict: + response = client.post("/api/v2/auth/register", json={"username": username, "password": password}) + assert response.status_code == 201, response.text + return response.json() + + +def _register_v2_device(client, bootstrap_token: str, label: str) -> dict: + response = client.post( + "/api/v2/devices/register", + headers=_auth_header(bootstrap_token), + json={ + "label": label, + "pub_sign_key": f"{label}-sign-key-material-0123456789abcdef", + "pub_dh_key": f"{label}-dh-key-material-0123456789abcdef", + }, + ) + assert response.status_code == 200, response.text + return response.json() + + +def test_v2_create_dm_peer_address_alias_routes_local(client) -> None: + settings = get_settings() + previous_aliases = settings.local_server_aliases + settings.local_server_aliases = "localhost:8000,192.168.1.50:8000" + try: + alice = _register_v2_user(client, "alice_alias_dm") + bob = _register_v2_user(client, "bob_alias_dm") + alice_device = _register_v2_device(client, alice["tokens"]["bootstrap_token"], "alice-device") + _register_v2_device(client, bob["tokens"]["bootstrap_token"], "bob-device") + + response = client.post( + "/api/v2/conversations/dm", + headers=_auth_header(alice_device["tokens"]["access_token"]), + json={"peer_address": "bob_alias_dm@localhost:8000"}, + ) + finally: + settings.local_server_aliases = previous_aliases + + assert response.status_code == 200, response.text + body = response.json() + assert body["kind"] == "local" + assert body["peer_username"] == "bob_alias_dm" + assert body["peer_address"] == "bob_alias_dm@local.invalid" + + +def test_v2_resolve_devices_alias_routes_local_without_federation(client, monkeypatch) -> None: + settings = get_settings() + previous_aliases = settings.local_server_aliases + settings.local_server_aliases = "localhost:8000,192.168.1.50:8000" + + async def _unexpected_remote_lookup(_peer_onion: str, _username: str): + raise AssertionError("remote federation lookup should not be used for local aliases") + + monkeypatch.setattr( + "app.services.federation_client.federation_client.get_remote_user_devices_v2", + _unexpected_remote_lookup, + ) + + try: + alice = _register_v2_user(client, "alice_alias_resolve") + bob = _register_v2_user(client, "bob_alias_resolve") + alice_device = _register_v2_device(client, alice["tokens"]["bootstrap_token"], "alice-device") + _register_v2_device(client, bob["tokens"]["bootstrap_token"], "bob-device") + + response = client.get( + "/api/v2/users/resolve-devices", + headers=_auth_header(alice_device["tokens"]["access_token"]), + params={"peer_address": "bob_alias_resolve@localhost:8000"}, + ) + finally: + settings.local_server_aliases = previous_aliases + + assert response.status_code == 200, response.text + body = response.json() + assert body["username"] == "bob_alias_resolve" + assert body["attachment_policy_source"] == "local" diff --git a/server/tests/integration/test_v2_presence.py b/server/tests/integration/test_v2_presence.py new file mode 100644 index 0000000..5a1aa7f --- /dev/null +++ b/server/tests/integration/test_v2_presence.py @@ -0,0 +1,60 @@ +from tests.integration.test_v2_security_core import ( + _auth_header, + _new_device_material, + _register_device_v2, + _register_v2_user, +) + + +def test_v2_presence_set_and_resolve(client) -> None: + alice_register = _register_v2_user(client, "alice_v2_presence") + bob_register = _register_v2_user(client, "bob_v2_presence") + + alice_device = _new_device_material("alice-presence-device") + bob_device = _new_device_material("bob-presence-device") + + alice_tokens = _register_device_v2(client, alice_register["tokens"]["bootstrap_token"], alice_device)["tokens"] + bob_tokens = _register_device_v2(client, bob_register["tokens"]["bootstrap_token"], bob_device)["tokens"] + + alice_address = "alice_v2_presence@local.invalid" + + offline_resolve = client.post( + "/api/v2/presence/resolve", + headers=_auth_header(bob_tokens["access_token"]), + json={"peer_addresses": [alice_address]}, + ) + assert offline_resolve.status_code == 200, offline_resolve.text + assert offline_resolve.json()["peers"][0]["status"] == "offline" + + with client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens["access_token"])): + set_dnd = client.post( + "/api/v2/presence/set", + headers=_auth_header(alice_tokens["access_token"]), + json={"status": "dnd"}, + ) + assert set_dnd.status_code == 200, set_dnd.text + assert set_dnd.json()["status"] == "dnd" + + resolve_dnd = client.post( + "/api/v2/presence/resolve", + headers=_auth_header(bob_tokens["access_token"]), + json={"peer_addresses": [alice_address]}, + ) + assert resolve_dnd.status_code == 200, resolve_dnd.text + assert resolve_dnd.json()["peers"][0]["status"] == "dnd" + + set_inactive = client.post( + "/api/v2/presence/set", + headers=_auth_header(alice_tokens["access_token"]), + json={"status": "inactive"}, + ) + assert set_inactive.status_code == 200, set_inactive.text + assert set_inactive.json()["status"] == "inactive" + + resolve_inactive = client.post( + "/api/v2/presence/resolve", + headers=_auth_header(bob_tokens["access_token"]), + json={"peer_addresses": [alice_address]}, + ) + assert resolve_inactive.status_code == 200, resolve_inactive.text + assert resolve_inactive.json()["peers"][0]["status"] == "inactive" diff --git a/server/tests/integration/test_v2_security_core.py b/server/tests/integration/test_v2_security_core.py new file mode 100644 index 0000000..71cf0d2 --- /dev/null +++ b/server/tests/integration/test_v2_security_core.py @@ -0,0 +1,337 @@ +import base64 +import hashlib +import time +import uuid +from datetime import datetime +from typing import Any + +from nacl import encoding, public, signing + +from app.services.auth_service_v2 import canonical_bind_device_string +from app.services.message_service_v2 import canonical_message_signature_string +from app.services.prekey_service_v2 import canonical_signed_prekey_string + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _b64(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +def _sha256_hex(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _new_device_material(label: str) -> dict[str, Any]: + sign_sk = signing.SigningKey.generate() + sign_pk_b64 = sign_sk.verify_key.encode(encoder=encoding.Base64Encoder).decode("utf-8") + dh_sk = public.PrivateKey.generate() + dh_pk_b64 = _b64(bytes(dh_sk.public_key)) + return { + "label": label, + "sign_sk": sign_sk, + "sign_pk_b64": sign_pk_b64, + "dh_pk_b64": dh_pk_b64, + } + + +def _register_v2_user(client, username: str, password: str = "password123") -> dict[str, Any]: + response = client.post("/api/v2/auth/register", json={"username": username, "password": password}) + assert response.status_code == 201, response.text + return response.json() + + +def _login_v2_user(client, username: str, password: str = "password123") -> dict[str, Any]: + response = client.post("/api/v2/auth/login", json={"username": username, "password": password}) + assert response.status_code == 200, response.text + return response.json() + + +def _register_device_v2(client, bootstrap_token: str, material: dict[str, Any]) -> dict[str, Any]: + response = client.post( + "/api/v2/devices/register", + headers=_auth_header(bootstrap_token), + json={ + "label": material["label"], + "pub_sign_key": material["sign_pk_b64"], + "pub_dh_key": material["dh_pk_b64"], + }, + ) + assert response.status_code == 200, response.text + return response.json() + + +def _aggregate_chain_hash(sender_prev_hash: str, client_message_id: str, sent_at_ms: int, hash_material: list[str]) -> str: + aggregate = _sha256_hex("|".join(sorted(hash_material)).encode("utf-8")) + chain_data = "\n".join([sender_prev_hash, client_message_id, str(sent_at_ms), aggregate]).encode("utf-8") + return _sha256_hex(chain_data) + + +def test_v2_bootstrap_bind_refresh_and_revoke(client) -> None: + register_body = _register_v2_user(client, "alice_v2_auth") + bootstrap = register_body["tokens"]["bootstrap_token"] + + device1 = _new_device_material("alice-laptop") + register_device_body = _register_device_v2(client, bootstrap, device1) + tokens = register_device_body["tokens"] + device_uid = tokens["device_uid"] + + login_body = _login_v2_user(client, "alice_v2_auth") + bootstrap_again = login_body["tokens"]["bootstrap_token"] + + nonce = "bindnonce123" + timestamp_ms = int(time.time() * 1000) + canonical = canonical_bind_device_string( + register_body["user"]["id"], + device_uid, + nonce, + timestamp_ms, + ) + proof_sig = device1["sign_sk"].sign(canonical).signature + + bind = client.post( + "/api/v2/auth/bind-device", + headers=_auth_header(bootstrap_again), + json={ + "device_uid": device_uid, + "nonce": nonce, + "timestamp_ms": timestamp_ms, + "proof_signature_b64": _b64(proof_sig), + }, + ) + assert bind.status_code == 200, bind.text + bound_tokens = bind.json()["tokens"] + assert bound_tokens["device_uid"] == device_uid + + refreshed = client.post( + "/api/v2/auth/refresh", + json={"refresh_token": bound_tokens["refresh_token"]}, + ) + assert refreshed.status_code == 200, refreshed.text + assert refreshed.json()["tokens"]["device_uid"] == device_uid + + revoke = client.post( + f"/api/v2/devices/{device_uid}/revoke", + headers=_auth_header(refreshed.json()["tokens"]["access_token"]), + ) + assert revoke.status_code == 200, revoke.text + assert revoke.json()["status"] == "revoked" + + me_after_revoke = client.get( + "/api/v2/me", + headers=_auth_header(refreshed.json()["tokens"]["access_token"]), + ) + assert me_after_revoke.status_code == 401 + + +def test_v2_multi_device_signed_fanout_and_ack(client) -> None: + alice_register = _register_v2_user(client, "alice_v2_msg") + bob_register = _register_v2_user(client, "bob_v2_msg") + + alice_device_1 = _new_device_material("alice-laptop") + alice_device_2 = _new_device_material("alice-phone") + bob_device_1 = _new_device_material("bob-laptop") + + alice_tokens_1 = _register_device_v2(client, alice_register["tokens"]["bootstrap_token"], alice_device_1)["tokens"] + bob_tokens_1 = _register_device_v2(client, bob_register["tokens"]["bootstrap_token"], bob_device_1)["tokens"] + + alice_login = _login_v2_user(client, "alice_v2_msg") + alice_tokens_2 = _register_device_v2(client, alice_login["tokens"]["bootstrap_token"], alice_device_2)["tokens"] + + alice_access = alice_tokens_1["access_token"] + bob_access = bob_tokens_1["access_token"] + alice_device_1_uid = alice_tokens_1["device_uid"] + alice_device_2_uid = alice_tokens_2["device_uid"] + + dm = client.post( + "/api/v2/conversations/dm", + headers=_auth_header(alice_access), + json={"peer_username": "bob_v2_msg"}, + ) + assert dm.status_code == 200, dm.text + conversation_id = dm.json()["id"] + + resolved_peer = client.get( + "/api/v2/users/resolve-devices", + headers=_auth_header(alice_access), + params={"peer_address": "bob_v2_msg@local.invalid"}, + ) + assert resolved_peer.status_code == 200, resolved_peer.text + bob_devices = resolved_peer.json()["devices"] + assert len(bob_devices) == 1 + + alice_devices = client.get("/api/v2/devices", headers=_auth_header(alice_access)) + assert alice_devices.status_code == 200, alice_devices.text + active_alice_other = [d for d in alice_devices.json() if d["device_uid"] == alice_device_2_uid] + assert len(active_alice_other) == 1 + + client_message_id = str(uuid.uuid4()) + sent_at_ms = int(time.time() * 1000) + sender_prev_hash = "" + sender_sign_pk_b64 = alice_device_1["sign_pk_b64"] + + envelope_targets = [ + ("bob_v2_msg@local.invalid", bob_devices[0]["device_uid"], b"hello-bob"), + ("alice_v2_msg@local.invalid", alice_device_2_uid, b"self-mirror"), + ] + + envelopes = [] + hash_material: list[str] = [] + for recipient_address, recipient_device_uid, plaintext in envelope_targets: + ciphertext_b64 = _b64(plaintext) + ciphertext_hash = _sha256_hex(plaintext) + aad_hash = _sha256_hex(b"") + hash_material.append(f"{recipient_device_uid}:{ciphertext_hash}:{aad_hash}") + envelopes.append( + { + "recipient_user_address": recipient_address, + "recipient_device_uid": recipient_device_uid, + "ciphertext_b64": ciphertext_b64, + "aad_b64": None, + "signature_b64": "", + "sender_device_pubkey": sender_sign_pk_b64, + } + ) + + sender_chain_hash = _aggregate_chain_hash(sender_prev_hash, client_message_id, sent_at_ms, hash_material) + + for envelope in envelopes: + canonical = canonical_message_signature_string( + sender_address="alice_v2_msg@local.invalid", + sender_device_uid=alice_device_1_uid, + recipient_user_address=envelope["recipient_user_address"], + recipient_device_uid=envelope["recipient_device_uid"], + client_message_id=client_message_id, + sent_at_ms=sent_at_ms, + sender_prev_hash=sender_prev_hash, + sender_chain_hash=sender_chain_hash, + ciphertext_hash=_sha256_hex(base64.b64decode(envelope["ciphertext_b64"])), + aad_hash=_sha256_hex(b""), + ) + envelope["signature_b64"] = _b64(alice_device_1["sign_sk"].sign(canonical).signature) + + payload = { + "conversation_id": conversation_id, + "client_message_id": client_message_id, + "sent_at_ms": sent_at_ms, + "sender_prev_hash": sender_prev_hash, + "sender_chain_hash": sender_chain_hash, + "envelopes": envelopes, + } + + with client.websocket_connect("/api/v2/ws", headers=_auth_header(bob_access)) as bob_ws: + with client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens_2["access_token"])) as alice_ws: + send = client.post("/api/v2/messages/send", headers=_auth_header(alice_access), json=payload) + assert send.status_code == 200, send.text + assert send.json()["duplicate"] is False + + bob_event = bob_ws.receive_json() + assert bob_event["type"] == "message.new" + assert bob_event["message"]["sender_device_uid"] == alice_device_1_uid + bob_ws.send_json({"type": "message.ack", "copy_id": bob_event["copy_id"]}) + + alice_mirror_event = alice_ws.receive_json() + assert alice_mirror_event["type"] == "message.new" + assert alice_mirror_event["message"]["sender_device_uid"] == alice_device_1_uid + alice_ws.send_json({"type": "message.ack", "copy_id": alice_mirror_event["copy_id"]}) + + duplicate = client.post("/api/v2/messages/send", headers=_auth_header(alice_access), json=payload) + assert duplicate.status_code == 200 + assert duplicate.json()["duplicate"] is True + + +def test_v2_prekey_upload_and_resolve(client) -> None: + alice_register = _register_v2_user(client, "alice_v2_prekeys") + bob_register = _register_v2_user(client, "bob_v2_prekeys") + + alice_device = _new_device_material("alice-prekey-device") + bob_device = _new_device_material("bob-prekey-device") + + alice_tokens = _register_device_v2(client, alice_register["tokens"]["bootstrap_token"], alice_device)["tokens"] + bob_tokens = _register_device_v2(client, bob_register["tokens"]["bootstrap_token"], bob_device)["tokens"] + + expires_at = "2030-01-01T00:00:00+00:00" + signed_prekey_pub = alice_device["dh_pk_b64"] + canonical = canonical_signed_prekey_string( + alice_tokens["device_uid"], + 1, + signed_prekey_pub, + datetime.fromisoformat(expires_at), + ) + signature_b64 = _b64(alice_device["sign_sk"].sign(canonical).signature) + + upload = client.post( + "/api/v2/keys/prekeys/upload", + headers=_auth_header(alice_tokens["access_token"]), + json={ + "signed_prekey": { + "key_id": 1, + "pub_x25519_b64": signed_prekey_pub, + "sig_by_device_sign_key_b64": signature_b64, + "expires_at": expires_at, + }, + "one_time_prekeys": [], + }, + ) + assert upload.status_code == 200, upload.text + + resolve = client.get( + "/api/v2/users/resolve-prekeys", + headers=_auth_header(bob_tokens["access_token"]), + params={"peer_address": "alice_v2_prekeys@local.invalid"}, + ) + assert resolve.status_code == 200, resolve.text + body = resolve.json() + assert body["username"] == "alice_v2_prekeys" + assert len(body["devices"]) == 1 + assert body["devices"][0]["signed_prekey"]["key_id"] == 1 + assert body["devices"][0]["signed_prekey"]["pub_x25519_b64"] == signed_prekey_pub + + +def test_v2_ratchet_send_rejected_when_feature_disabled(client) -> None: + alice_register = _register_v2_user(client, "alice_v2_ratchet_off") + bob_register = _register_v2_user(client, "bob_v2_ratchet_off") + alice_device = _new_device_material("alice-ratchet-off") + bob_device = _new_device_material("bob-ratchet-off") + + alice_tokens = _register_device_v2(client, alice_register["tokens"]["bootstrap_token"], alice_device)["tokens"] + _register_device_v2(client, bob_register["tokens"]["bootstrap_token"], bob_device) + + dm = client.post( + "/api/v2/conversations/dm", + headers=_auth_header(alice_tokens["access_token"]), + json={"peer_username": "bob_v2_ratchet_off"}, + ) + assert dm.status_code == 200, dm.text + conversation_id = dm.json()["id"] + + client_message_id = str(uuid.uuid4()) + sent_at_ms = int(time.time() * 1000) + payload = { + "conversation_id": conversation_id, + "encryption_mode": "ratchet_v0_2b1", + "client_message_id": client_message_id, + "sent_at_ms": sent_at_ms, + "sender_prev_hash": "", + "sender_chain_hash": "f" * 64, + "envelopes": [ + { + "recipient_user_address": "bob_v2_ratchet_off@local.invalid", + "recipient_device_uid": "missing-device", + "ciphertext_b64": _b64(b"hello"), + "aad_b64": None, + "signature_b64": _b64(b"x" * 64), + "sender_device_pubkey": alice_device["sign_pk_b64"], + "ratchet_header": {"v": "dr_v1", "dh_pub": alice_device["dh_pk_b64"], "n": 0, "pn": 0}, + } + ], + } + response = client.post( + "/api/v2/messages/send", + headers=_auth_header(alice_tokens["access_token"]), + json=payload, + ) + assert response.status_code == 400 + assert "disabled" in response.text.lower() diff --git a/server/tests/integration/test_v2_webrtc_signaling.py b/server/tests/integration/test_v2_webrtc_signaling.py new file mode 100644 index 0000000..7875f94 --- /dev/null +++ b/server/tests/integration/test_v2_webrtc_signaling.py @@ -0,0 +1,84 @@ +import base64 +from typing import Any + +from nacl import encoding, public, signing + + +def _auth_header(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +def _b64(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +def _new_device_material(label: str) -> dict[str, Any]: + sign_sk = signing.SigningKey.generate() + sign_pk_b64 = sign_sk.verify_key.encode(encoder=encoding.Base64Encoder).decode("utf-8") + dh_sk = public.PrivateKey.generate() + dh_pk_b64 = _b64(bytes(dh_sk.public_key)) + return { + "label": label, + "sign_pk_b64": sign_pk_b64, + "dh_pk_b64": dh_pk_b64, + } + + +def _register_user_v2(client, username: str) -> dict[str, Any]: + response = client.post("/api/v2/auth/register", json={"username": username, "password": "password123"}) + assert response.status_code == 201, response.text + return response.json() + + +def _register_device_v2(client, bootstrap_token: str, material: dict[str, Any]) -> dict[str, Any]: + response = client.post( + "/api/v2/devices/register", + headers=_auth_header(bootstrap_token), + json={ + "label": material["label"], + "pub_sign_key": material["sign_pk_b64"], + "pub_dh_key": material["dh_pk_b64"], + }, + ) + assert response.status_code == 200, response.text + return response.json() + + +def test_v2_ws_webrtc_offer_disabled_returns_call_error(client) -> None: + alice_reg = _register_user_v2(client, "alice_v2_webrtc_off") + bob_reg = _register_user_v2(client, "bob_v2_webrtc_off") + + alice_device = _new_device_material("alice-v2-webrtc-device") + bob_device = _new_device_material("bob-v2-webrtc-device") + alice_tokens = _register_device_v2(client, alice_reg["tokens"]["bootstrap_token"], alice_device)["tokens"] + bob_tokens = _register_device_v2(client, bob_reg["tokens"]["bootstrap_token"], bob_device)["tokens"] + + dm = client.post( + "/api/v2/conversations/dm", + headers=_auth_header(alice_tokens["access_token"]), + json={"peer_username": "bob_v2_webrtc_off"}, + ) + assert dm.status_code == 200, dm.text + conversation_id = dm.json()["id"] + + with client.websocket_connect("/api/v2/ws", headers=_auth_header(alice_tokens["access_token"])) as alice_ws: + with client.websocket_connect("/api/v2/ws", headers=_auth_header(bob_tokens["access_token"])) as bob_ws: + alice_ws.send_json({"type": "call.offer", "conversation_id": conversation_id}) + incoming = bob_ws.receive_json() + assert incoming["type"] == "call.incoming" + call_id = incoming["call_id"] + assert alice_ws.receive_json()["type"] == "call.ringing" + bob_ws.send_json({"type": "call.accept", "call_id": call_id}) + assert alice_ws.receive_json()["type"] == "call.accepted" + assert bob_ws.receive_json()["type"] == "call.accepted" + + alice_ws.send_json( + { + "type": "call.webrtc.offer", + "call_id": call_id, + "sdp": "v=0\r\no=- 0 0 IN IP4 127.0.0.1\r\ns=Blackwire\r\nt=0 0\r\nm=audio 9 RTP/AVP 0\r\n", + } + ) + err = alice_ws.receive_json() + assert err["type"] == "call.error" + assert err["code"] == "webrtc_disabled" diff --git a/server/tests/integration/test_ws_auth.py b/server/tests/integration/test_ws_auth.py new file mode 100644 index 0000000..0e80d94 --- /dev/null +++ b/server/tests/integration/test_ws_auth.py @@ -0,0 +1,12 @@ +from tests.helpers import register_user + + +def test_websocket_accepts_query_access_token(client) -> None: + alice = register_user(client, "alice_ws_query_auth") + access_token = alice["tokens"]["access_token"] + + with client.websocket_connect(f"/api/v1/ws?access_token={access_token}") as ws: + ws.send_json({"type": "unsupported"}) + response = ws.receive_json() + assert response["type"] == "error" + assert response["code"] == "unsupported_event" diff --git a/server/tests/unit/test_attachment_policy_config.py b/server/tests/unit/test_attachment_policy_config.py new file mode 100644 index 0000000..2455e76 --- /dev/null +++ b/server/tests/unit/test_attachment_policy_config.py @@ -0,0 +1,37 @@ +import pytest + +from app.config import Settings + + +def test_attachment_policy_defaults_match_v02c_plan() -> None: + settings = Settings( + environment="test", + database_url="sqlite+aiosqlite:///./test.db", + auto_create_tables=False, + attachment_inline_max_bytes=10 * 1024 * 1024, + attachment_hard_ceiling_bytes=32 * 1024 * 1024, + ) + assert settings.effective_attachment_inline_max_bytes() == 10 * 1024 * 1024 + assert settings.attachment_hard_ceiling_bytes == 32 * 1024 * 1024 + + +def test_attachment_policy_rejects_ciphertext_above_hard_ceiling() -> None: + with pytest.raises(ValueError): + Settings( + environment="test", + database_url="sqlite+aiosqlite:///./test.db", + auto_create_tables=False, + max_ciphertext_bytes=33 * 1024 * 1024, + attachment_hard_ceiling_bytes=32 * 1024 * 1024, + ) + + +def test_attachment_policy_rejects_client_body_smaller_than_ciphertext_limit() -> None: + with pytest.raises(ValueError): + Settings( + environment="test", + database_url="sqlite+aiosqlite:///./test.db", + auto_create_tables=False, + max_ciphertext_bytes=1024, + max_client_message_body_bytes=512, + ) diff --git a/server/tests/unit/test_rate_limit_weighted.py b/server/tests/unit/test_rate_limit_weighted.py new file mode 100644 index 0000000..b84719f --- /dev/null +++ b/server/tests/unit/test_rate_limit_weighted.py @@ -0,0 +1,13 @@ +import pytest +from fastapi import HTTPException + +from app.services.rate_limit import RateLimiter + + +@pytest.mark.asyncio +async def test_weighted_rate_limit_blocks_when_units_exceed_limit() -> None: + limiter = RateLimiter() + await limiter.enforce_weighted("weighted:test", units=5, limit=10) + with pytest.raises(HTTPException) as exc: + await limiter.enforce_weighted("weighted:test", units=6, limit=10) + assert exc.value.status_code == 429 diff --git a/server/tests/unit/test_server_authority.py b/server/tests/unit/test_server_authority.py new file mode 100644 index 0000000..85c67db --- /dev/null +++ b/server/tests/unit/test_server_authority.py @@ -0,0 +1,18 @@ +from app.config import Settings +from app.services.server_authority import authority_matches, is_local_server_authority + + +def test_authority_matches_ignores_missing_port() -> None: + assert authority_matches("localhost:8000", "localhost") + assert authority_matches("127.0.0.1", "127.0.0.1:8000") + + +def test_is_local_server_authority_uses_configured_aliases() -> None: + settings = Settings( + environment="test", + database_url="sqlite+aiosqlite:///./test.db", + auto_create_tables=False, + federation_server_onion="local.invalid", + local_server_aliases="192.168.1.50:8000", + ) + assert is_local_server_authority("192.168.1.50:8000", settings) diff --git a/server/tests/unit/test_v2_schema_types.py b/server/tests/unit/test_v2_schema_types.py new file mode 100644 index 0000000..8053e5e --- /dev/null +++ b/server/tests/unit/test_v2_schema_types.py @@ -0,0 +1,8 @@ +from sqlalchemy import BigInteger + +from app.models.message_event import MessageEvent + + +def test_message_event_sent_at_ms_is_bigint() -> None: + column_type = MessageEvent.__table__.c.sent_at_ms.type + assert isinstance(column_type, BigInteger)