From d276fe6a207984c87c068f720bcb86e01f0cdf83 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 25 Jun 2025 14:35:38 -0700 Subject: [PATCH 1/5] feat!: compile-time safety for handshake Introducing the "TypeState" pattern to the handshake. This moves the runtime error variant, HandshakeOutOfOrder, to compile time. The type system now doesn't allow for out of order calls. --- protocol/benches/cipher_session.rs | 95 +-- protocol/fuzz/fuzz_targets/handshake.rs | 96 +-- protocol/src/handshake.rs | 905 +++++++++++++----------- protocol/src/io.rs | 108 ++- protocol/src/lib.rs | 239 ++++--- protocol/tests/round_trips.rs | 152 ++-- 6 files changed, 870 insertions(+), 725 deletions(-) diff --git a/protocol/benches/cipher_session.rs b/protocol/benches/cipher_session.rs index f206683..f9120ff 100644 --- a/protocol/benches/cipher_session.rs +++ b/protocol/benches/cipher_session.rs @@ -1,60 +1,69 @@ +// SPDX-License-Identifier: CC0-1.0 + #![feature(test)] extern crate test; -use bip324::{CipherSession, Handshake, InboundCipher, Network, OutboundCipher, PacketType, Role}; +use bip324::{ + CipherSession, Handshake, HandshakeAuthentication, InboundCipher, Initialized, Network, + OutboundCipher, PacketType, ReceivedKey, Role, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, +}; use test::{black_box, Bencher}; fn create_cipher_session_pair() -> (CipherSession, CipherSession) { - // Create a proper handshake between Alice and Bob. - let mut alice_init_buffer = vec![0u8; 64]; - let mut alice_handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - None, - &mut alice_init_buffer, - ) - .unwrap(); - - let mut bob_init_buffer = vec![0u8; 100]; - let mut bob_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut bob_init_buffer, - ) - .unwrap(); - - // Bob completes materials with Alice's key. - bob_handshake - .complete_materials( - alice_init_buffer[..64].try_into().unwrap(), - &mut bob_init_buffer[64..], - None, - ) + // Send Alice's key. + let alice_handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut alice_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let alice_handshake = alice_handshake + .send_key(None, &mut alice_key_buffer) .unwrap(); - // Alice completes materials with Bob's key. - let mut alice_response_buffer = vec![0u8; 36]; - alice_handshake - .complete_materials( - bob_init_buffer[..64].try_into().unwrap(), - &mut alice_response_buffer, - None, - ) + // Send Bob's key + let bob_handshake = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut bob_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let bob_handshake = bob_handshake.send_key(None, &mut bob_key_buffer).unwrap(); + + // Alice receives Bob's key. + let alice_handshake = alice_handshake + .receive_key(bob_key_buffer.try_into().unwrap()) .unwrap(); - // Authenticate. - let mut packet_buffer = vec![0u8; 4096]; - alice_handshake - .authenticate_garbage_and_version(&bob_init_buffer[64..], &mut packet_buffer) + // Bob receives Alice's key. + let bob_handshake = bob_handshake + .receive_key(alice_key_buffer.try_into().unwrap()) .unwrap(); - bob_handshake - .authenticate_garbage_and_version(&alice_response_buffer, &mut packet_buffer) + + // Alice sends version. + let mut alice_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let alice_handshake = alice_handshake + .send_version(&mut alice_version_buffer, None) + .unwrap(); + + // Bob sends version. + let mut bob_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let bob_handshake = bob_handshake + .send_version(&mut bob_version_buffer, None) .unwrap(); - let alice = alice_handshake.finalize().unwrap(); - let bob = bob_handshake.finalize().unwrap(); + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Alice receives Bob's version. + let alice = match alice_handshake + .receive_version(&bob_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; + + // Bob receives Alice's version. + let bob = match bob_handshake + .receive_version(&alice_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; (alice, bob) } diff --git a/protocol/fuzz/fuzz_targets/handshake.rs b/protocol/fuzz/fuzz_targets/handshake.rs index bd8c5c9..09b9ee6 100644 --- a/protocol/fuzz/fuzz_targets/handshake.rs +++ b/protocol/fuzz/fuzz_targets/handshake.rs @@ -1,58 +1,60 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! ## Expected Outcomes +//! +//! * Most runs will fail with invalid EC points or handshake failures. +//! * No panics, crashes, or memory safety issues should occur. +//! * The implementation should handle all inputs gracefully. + #![no_main] -use bip324::{Handshake, Network, Role, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES}; +use bip324::{ + Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role, + NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, +}; use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { - // Skip if data is too small for an interesting test. - if data.len() < 64 { + // Skip if data is too small for a meaningful test. + // We need at least 64 bytes for the public key and at + // least 30 more for some interesting garbage, decoy, version bytes. + if data.len() < 100 { return; } - let mut initiator_pubkey = [0u8; 64]; - let mut handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - None, - &mut initiator_pubkey, - ) - .unwrap(); - - let mut responder_pubkey = [0u8; 64]; - let _responder_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut responder_pubkey, - ) - .unwrap(); - - // Create a mutation of the responder's bytes. - let mut garbage_and_version = [0u8; 36]; - let copy_len = std::cmp::min(data.len(), garbage_and_version.len()); - garbage_and_version[..copy_len].copy_from_slice(&data[..copy_len]); - - // Create mutation of the responder's public key. - // The key is either completely random or slightly tweaked. + // Initiator side of the handshake. + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut initiator_pubkey = vec![0u8; Handshake::::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut initiator_pubkey).unwrap(); + + // Use the first 64 bytes of fuzz data as the responder's public key. let mut fuzzed_responder_pubkey = [0u8; 64]; - if data.len() >= 128 { - fuzzed_responder_pubkey.copy_from_slice(&data[64..128]); - } else { - fuzzed_responder_pubkey.copy_from_slice(&responder_pubkey); - for (i, b) in data - .iter() - .enumerate() - .take(fuzzed_responder_pubkey.len()) - .skip(copy_len) - { - fuzzed_responder_pubkey[i % 64] ^= b; // XOR to make controlled changes. + fuzzed_responder_pubkey.copy_from_slice(&data[..64]); + + // Attempt to receive the fuzzed key. + let handshake = match handshake.receive_key(fuzzed_responder_pubkey) { + Ok(h) => h, + Err(_) => return, // Invalid key rejected successfully. + }; + + // Send version message just to move the state of the handshake along. + let mut version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); + + // Try to receive and authenticate the fuzzed garbage and version data. + let garbage_and_version = Vec::from(&data[64..]); + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + match handshake.receive_version(&garbage_and_version, &mut packet_buffer) { + Ok(HandshakeAuthentication::Complete { .. }) => { + // Handshake completed successfully. + // This should only happen with some very lucky random bytes. + } + Ok(HandshakeAuthentication::NeedMoreData(_)) => { + // Handshake needs more ciphertext. + // This is an expected outcome for fuzzed inputs. + } + Err(_) => { + // Authentication or parsing failed. + // This is an expected outcome for fuzzed inputs. } } - - // Try to complete the materials and authenticate with the fuzzed key and data. - // Exercising malformed public key handling. - let _ = handshake.complete_materials(fuzzed_responder_pubkey, &mut garbage_and_version, None); - // Check how a broken handshake is handled. - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; // Initial buffer for decoy and version packets - let _ = handshake.authenticate_garbage_and_version(&garbage_and_version, &mut packet_buffer); - let _ = handshake.finalize(); }); diff --git a/protocol/src/handshake.rs b/protocol/src/handshake.rs index 51311e9..8976784 100644 --- a/protocol/src/handshake.rs +++ b/protocol/src/handshake.rs @@ -1,3 +1,14 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! # BIP-324 V2 Transport Protocol Handshake +//! +//! 1. **Key Exchange**: Both peers generate and exchange public keys using ElligatorSwift encoding. +//! 2. **Garbage**: Optional garbage bytes are sent to obscure traffic patterns. +//! 3. **Decoy Packets**: Optional decoy packets can be sent to further obscure traffic patterns. +//! 4. **Version Authentication**: Version packets are exchanged to negotiate the protocol version for the channel. +//! 5. **Session Establishment**: The secure communication channel is ready for message exchange. +//! ``` + use bitcoin::{ key::Secp256k1, secp256k1::{ @@ -9,7 +20,7 @@ use bitcoin::{ use rand::Rng; use crate::{ - CipherSession, Error, OutboundCipher, PacketType, Role, SessionKeyMaterial, + CipherSession, Error, InboundCipher, OutboundCipher, PacketType, Role, SessionKeyMaterial, NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES, VERSION_CONTENT, }; @@ -27,94 +38,91 @@ pub struct EcdhPoint { elligator_swift: ElligatorSwift, } +/// **Initial state** of the handshake state machine which holds local secret materials. +pub struct Initialized { + point: EcdhPoint, +} + +/// **Second state** after sending the local public key. +pub struct SentKey { + point: EcdhPoint, + bytes_written: usize, +} + +/// **Third state** after receiving the remote's public key and +/// generating the shared secret materials for the session. +pub struct ReceivedKey { + session_keys: SessionKeyMaterial, +} + +/// **Fourth state** after sending the version packet. +pub struct SentVersion { + cipher: CipherSession, + remote_garbage_terminator: [u8; NUM_GARBAGE_TERMINTOR_BYTES], + bytes_written: usize, +} + +/// Success variants for receive_version. +pub enum HandshakeAuthentication<'a> { + /// Successfully completed. + Complete { + cipher: CipherSession, + bytes_consumed: usize, + }, + /// Need more data - returns handshake for caller to retry with more ciphertext. + NeedMoreData(Handshake<'a, SentVersion>), +} + /// Handshake state-machine to establish the secret material in the communication channel. /// -/// A handshake is first initialized to create local materials needed to setup communication -/// channel between an *initiator* and a *responder*. The next step is to call `complete_materials` -/// no matter if initiator or responder, however the responder should already have the -/// necessary materials from their peers request. `complete_materials` creates the response -/// packet to be sent from each peer and `authenticate_garbage_and_version` is then used -/// to verify the handshake. Finally, the `finalized` method is used to consumer the handshake -/// and return a cipher session for further communication on the channel. -pub struct Handshake<'a> { +/// The handshake progresses through multiple states, enforcing the protocol sequence at compile time. +/// +/// 1. `Initialized` - Initial state with local secret materials. +/// 2. `SentKey` - After sending local public key and optional garbage. +/// 3. `ReceivedKey` - After receiving remote's public key. +/// 4. `SentVersion` - After sending local garbage terminator and version packet. +/// 5. Complete - After receiving and authenticating remote's garbage, garbage terminator, decoy packets, and version packet. +pub struct Handshake<'a, State> { /// Bitcoin network both peers are operating on. network: Network, /// Local role in the handshake, initiator or responder. role: Role, - /// Local point for key exchange. - point: EcdhPoint, - /// Optional garbage bytes to send along in handshake. + /// Optional local garbage bytes to send along in handshake. garbage: Option<&'a [u8]>, - /// Peers expected garbage terminator. - remote_garbage_terminator: Option<[u8; NUM_GARBAGE_TERMINTOR_BYTES]>, - /// Cipher session output. - cipher_session: Option, - /// Decrypted length for next packet. Store state between authentication attempts to avoid resetting ciphers. - current_packet_length_bytes: Option, - /// Processesed buffer index. Store state between authentication attempts to avoid resetting ciphers. - current_buffer_index: usize, + /// State-specific data. + state: State, } -impl<'a> Handshake<'a> { - /// Initialize a V2 transport handshake with a peer. - /// - /// # Arguments - /// - /// * `network` - The bitcoin network which both peers operate on. - /// * `garbage` - Optional garbage to send in handshake. - /// * `buffer` - Packet buffer to send to peer which will include initial materials for handshake + garbage. - /// - /// # Returns - /// - /// An initialized handshake which must be finalized. - /// - /// # Errors - /// - /// Fails if their was an error generating the keypair. +// Methods available in all states +impl<'a, State> Handshake<'a, State> { + /// Get the network this handshake is operating on. + pub fn network(&self) -> Network { + self.network + } + + /// Get the local role in the handshake. + pub fn role(&self) -> Role { + self.role + } +} + +// Initialized state implementation +impl<'a> Handshake<'a, Initialized> { + /// Initialize a V2 transport handshake with a remote peer. #[cfg(feature = "std")] - pub fn new( - network: Network, - role: Role, - garbage: Option<&'a [u8]>, - buffer: &mut [u8], - ) -> Result { + pub fn new(network: Network, role: Role) -> Result { let mut rng = rand::thread_rng(); let curve = Secp256k1::signing_only(); - Self::new_with_rng(network, role, garbage, buffer, &mut rng, &curve) + Self::new_with_rng(network, role, &mut rng, &curve) } - /// Initialize a V2 transport handshake with a peer. - /// - /// # Arguments - /// - /// * `network` - The bitcoin network which both peers operate on. - /// * `garbage` - Optional garbage to send in handshake. - /// * `buffer` - Packet buffer to send to peer which will include initial materials for handshake + garbage. - /// * `rng` - Supplied Random Number Generator. - /// * `curve` - Supplied secp256k1 context. - /// - /// # Returns - /// - /// An initialized handshake which must be finalized. - /// - /// # Errors - /// - /// Fails if their was an error generating the keypair. + /// Initialize a V2 transport handshake with remote peer using supplied RNG and secp context. pub fn new_with_rng( network: Network, role: Role, - garbage: Option<&'a [u8]>, - buffer: &mut [u8], rng: &mut impl Rng, curve: &Secp256k1, ) -> Result { - if garbage - .as_ref() - .map_or(false, |g| g.len() > MAX_NUM_GARBAGE_BYTES) - { - return Err(Error::TooMuchGarbage); - }; - let mut secret_key_buffer = [0u8; 32]; rng.fill(&mut secret_key_buffer[..]); let sk = SecretKey::from_slice(&secret_key_buffer)?; @@ -126,54 +134,83 @@ impl<'a> Handshake<'a> { elligator_swift: es, }; - // Bounds check on the output buffer. - let required_bytes = garbage.map_or(NUM_ELLIGATOR_SWIFT_BYTES, |g| { - NUM_ELLIGATOR_SWIFT_BYTES + g.len() - }); - if buffer.len() < required_bytes { - return Err(Error::BufferTooSmall { required_bytes }); - }; + Ok(Handshake { + network, + role, + garbage: None, + state: Initialized { point }, + }) + } + + /// Calculate how many bytes send_key() will write to buffer. + pub fn send_key_len(garbage: Option<&[u8]>) -> usize { + NUM_ELLIGATOR_SWIFT_BYTES + garbage.map(|g| g.len()).unwrap_or(0) + } + + /// Send local public key and optional garbage. + pub fn send_key( + mut self, + garbage: Option<&'a [u8]>, + output_buffer: &mut [u8], + ) -> Result, Error> { + // Validate garbage length + if let Some(g) = garbage { + if g.len() > MAX_NUM_GARBAGE_BYTES { + return Err(Error::TooMuchGarbage); + } + } + + let required_len = Self::send_key_len(garbage); + if output_buffer.len() < required_len { + return Err(Error::BufferTooSmall { + required_bytes: required_len, + }); + } - buffer[0..64].copy_from_slice(&point.elligator_swift.to_array()); - if let Some(garbage) = garbage { - buffer[64..64 + garbage.len()].copy_from_slice(garbage); + // Write local ellswift public key. + output_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] + .copy_from_slice(&self.state.point.elligator_swift.to_array()); + let mut written = NUM_ELLIGATOR_SWIFT_BYTES; + + // Write garbage if provided. + if let Some(g) = garbage { + output_buffer[written..written + g.len()].copy_from_slice(g); + written += g.len(); } + // Store garbage for later use. + self.garbage = garbage; + Ok(Handshake { - network, - role, - point, - garbage, - remote_garbage_terminator: None, - cipher_session: None, - current_packet_length_bytes: None, - current_buffer_index: 0, + network: self.network, + role: self.role, + garbage: self.garbage, + state: SentKey { + point: self.state.point, + bytes_written: written, + }, }) } +} - /// Complete the secret material handshake and send the version packet to peer. - /// - /// # Arguments - /// - /// * `their_elliswift` - The key material of the remote peer. - /// * `response_buffer` - Buffer to write response for remote peer which includes the garbage terminator and version packet. - /// * `decoys` - Contents for decoy packets sent before version packet. - /// - /// # Errors - /// - /// * `V1Protocol` - The remote is communicating on the V1 protocol instead of V2. Caller can fallback - /// to V1 if they want. - pub fn complete_materials( - &mut self, - their_elliswift: [u8; NUM_ELLIGATOR_SWIFT_BYTES], - response_buffer: &mut [u8], - decoys: Option<&[&[u8]]>, - ) -> Result<(), Error> { - // Short circuit if the remote is sending the V1 protocol network bytes. - // Gives the caller an opportunity to fallback to V1 if they choose. +// SentKey state implementation +impl<'a> Handshake<'a, SentKey> { + /// Get how many bytes were written by send_key(). + pub fn bytes_written(&self) -> usize { + self.state.bytes_written + } + + /// Process received key material. + pub fn receive_key( + self, + their_key: [u8; NUM_ELLIGATOR_SWIFT_BYTES], + ) -> Result, Error> { + let their_ellswift = ElligatorSwift::from_array(their_key); + + // Check for V1 protocol magic bytes if self.network.magic() == bitcoin::p2p::Magic::from_bytes( - their_elliswift[..4] + their_key[..4] .try_into() .expect("64 byte array to have 4 byte prefix"), ) @@ -181,380 +218,450 @@ impl<'a> Handshake<'a> { return Err(Error::V1Protocol); } - let theirs = ElligatorSwift::from_array(their_elliswift); + // Compute session keys using ECDH + let (initiator_ellswift, responder_ellswift, secret, party) = match self.role { + Role::Initiator => ( + self.state.point.elligator_swift, + their_ellswift, + self.state.point.secret_key, + ElligatorSwiftParty::A, + ), + Role::Responder => ( + their_ellswift, + self.state.point.elligator_swift, + self.state.point.secret_key, + ElligatorSwiftParty::B, + ), + }; + + let session_keys = SessionKeyMaterial::from_ecdh( + initiator_ellswift, + responder_ellswift, + secret, + party, + self.network, + )?; + + Ok(Handshake { + network: self.network, + role: self.role, + garbage: self.garbage, + state: ReceivedKey { session_keys }, + }) + } +} + +// ReceivedKey state implementation +impl<'a> Handshake<'a, ReceivedKey> { + /// Calculate how many bytes send_version() will write to buffer. + pub fn send_version_len(decoys: Option<&[&[u8]]>) -> usize { + let mut len = NUM_GARBAGE_TERMINTOR_BYTES + + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); + + // Add decoy packets length. + if let Some(decoys) = decoys { + for decoy in decoys { + len += OutboundCipher::encryption_buffer_len(decoy.len()); + } + } + + len + } - // Check if the buffer is large enough for the garbage terminator. - if response_buffer.len() < NUM_GARBAGE_TERMINTOR_BYTES { + /// Send garbage terminator, optional decoys, and version packet. + pub fn send_version( + self, + output_buffer: &mut [u8], + decoys: Option<&[&[u8]]>, + ) -> Result, Error> { + let required_len = Self::send_version_len(decoys); + if output_buffer.len() < required_len { return Err(Error::BufferTooSmall { - required_bytes: NUM_GARBAGE_TERMINTOR_BYTES, + required_bytes: required_len, }); } - // Line up appropriate materials based on role and some - // garbage terminator haggling. - let materials = match self.role { - Role::Initiator => { - let materials = SessionKeyMaterial::from_ecdh( - self.point.elligator_swift, - theirs, - self.point.secret_key, - ElligatorSwiftParty::A, - self.network, - )?; - response_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] - .copy_from_slice(&materials.initiator_garbage_terminator); - self.remote_garbage_terminator = Some(materials.responder_garbage_terminator); + let mut cipher = CipherSession::new(self.state.session_keys.clone(), self.role); - materials + // Write garbage terminator and determine remote terminator. + let remote_garbage_terminator = match self.role { + Role::Initiator => { + output_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] + .copy_from_slice(&self.state.session_keys.initiator_garbage_terminator); + self.state.session_keys.responder_garbage_terminator } Role::Responder => { - let materials = SessionKeyMaterial::from_ecdh( - theirs, - self.point.elligator_swift, - self.point.secret_key, - ElligatorSwiftParty::B, - self.network, - )?; - response_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] - .copy_from_slice(&materials.responder_garbage_terminator); - self.remote_garbage_terminator = Some(materials.initiator_garbage_terminator); - - materials + output_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] + .copy_from_slice(&self.state.session_keys.responder_garbage_terminator); + self.state.session_keys.initiator_garbage_terminator } }; - let mut cipher_session = CipherSession::new(materials, self.role); - let mut start_index = NUM_GARBAGE_TERMINTOR_BYTES; + let mut bytes_written = NUM_GARBAGE_TERMINTOR_BYTES; + // Local garbage is authenticated in first packet no + // matter if it is a decoy or genuine. + let mut aad = self.garbage; - // Write any decoy packets and then the version packet. - // The first packet, no matter if decoy or genuinie version packet, needs - // to authenticate the garbage previously sent. if let Some(decoys) = decoys { - for (i, decoy) in decoys.iter().enumerate() { - let end_index = start_index + OutboundCipher::encryption_buffer_len(decoy.len()); - cipher_session.outbound().encrypt( + for decoy in decoys { + let packet_len = OutboundCipher::encryption_buffer_len(decoy.len()); + cipher.outbound().encrypt( decoy, - &mut response_buffer[start_index..end_index], + &mut output_buffer[bytes_written..bytes_written + packet_len], PacketType::Decoy, - if i == 0 { self.garbage } else { None }, + aad, )?; - - start_index = end_index; + aad = None; + bytes_written += packet_len; } } - cipher_session.outbound().encrypt( + // Write version packet + let version_packet_len = OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); + cipher.outbound().encrypt( &VERSION_CONTENT, - &mut response_buffer[start_index - ..start_index + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len())], + &mut output_buffer[bytes_written..bytes_written + version_packet_len], PacketType::Genuine, - if decoys.is_none() { self.garbage } else { None }, + aad, )?; + bytes_written += version_packet_len; - self.cipher_session = Some(cipher_session); + Ok(Handshake { + network: self.network, + role: self.role, + garbage: self.garbage, + state: SentVersion { + cipher, + remote_garbage_terminator, + bytes_written, + }, + }) + } +} - Ok(()) +// SentVersion state implementation +impl<'a> Handshake<'a, SentVersion> { + /// Get how many bytes were written by send_version(). + pub fn bytes_written(&self) -> usize { + self.state.bytes_written } - /// Authenticate the channel. - /// - /// Designed to be called multiple times until succesful in order to flush - /// garbage and decoy packets from channel. If a `BufferTooSmall ` is - /// returned, the buffer should be extended until `BufferTooSmall ` is - /// not returned. All other errors are fatal for the handshake and it should - /// be completely restarted. - /// - /// # Arguments - /// - /// * `buffer` - Should contain all garbage, the garbage terminator, any decoy packets, and finally the version packet received from peer. - /// * `packet_buffer` - Required memory allocation for decrypting decoy and version packets. - /// - /// # Error - /// - /// * `CiphertextTooSmall` - The buffer did not contain all required information and should be extended (e.g. read more off a socket) and authentication re-tried. - /// * `BufferTooSmall` - The supplied packet_buffer is not large enough for decrypting the decoy and version packets. - /// * `HandshakeOutOfOrder` - The handshake sequence is in a bad state and should be restarted. - /// * `MaxGarbageLength` - Buffer did not contain the garbage terminator, should not be retried. - pub fn authenticate_garbage_and_version( - &mut self, - buffer: &[u8], - packet_buffer: &mut [u8], - ) -> Result<(), Error> { - // Find the end of the garbage. - let (garbage, ciphertext) = self.split_garbage(buffer)?; - - // Flag to track if the version packet has been received to signal the end of the handshake. - let mut found_version_packet = false; - - // The first packet, even if it is a decoy packet, - // is used to authenticate the received garbage through - // the AAD. - if self.current_buffer_index == 0 { - found_version_packet = self.decrypt_packet(ciphertext, packet_buffer, Some(garbage))?; - } + /// Authenticate garbage and version packet. + pub fn receive_version( + mut self, + input_buffer: &[u8], + output_buffer: &mut [u8], + ) -> Result, Error> { + let (garbage, ciphertext) = match self.split_garbage(input_buffer) { + Ok(split) => split, + Err(Error::CiphertextTooSmall) => { + return Ok(HandshakeAuthentication::NeedMoreData(self)) + } + Err(e) => return Err(e), + }; - // If the first packet is a decoy, or if this is a follow up - // authentication attempt, the decoys need to be flushed and - // the version packet found. - // - // The version packet is essentially ignored in the current - // version of the protocol, but it does move the cipher - // states forward. It could be extended in the future. - while !found_version_packet { - found_version_packet = self.decrypt_packet(ciphertext, packet_buffer, None)?; + let mut aad = if garbage.is_empty() { + None + } else { + Some(garbage) + }; + let mut ciphertext_index = 0; + let mut found_version = false; + + // First packet authenticates remote garbage. + // Continue through decoys until we find version packet. + while !found_version { + match self.decrypt_packet(&ciphertext[ciphertext_index..], output_buffer, aad) { + Ok((packet_type, bytes_consumed)) => { + aad = None; + ciphertext_index += bytes_consumed; + found_version = matches!(packet_type, PacketType::Genuine); + } + Err(Error::CiphertextTooSmall) => { + return Ok(HandshakeAuthentication::NeedMoreData(self)) + } + Err(e) => return Err(e), + } } - Ok(()) + // Calculate total bytes consumed (garbage + terminator + packets) + let bytes_consumed = garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + ciphertext_index; + + Ok(HandshakeAuthentication::Complete { + cipher: self.state.cipher, + bytes_consumed, + }) + } + + /// Split buffer on garbage terminator. + fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { + let terminator = &self.state.remote_garbage_terminator; + + if let Some(index) = buffer + .windows(terminator.len()) + .position(|window| window == terminator) + { + Ok((&buffer[..index], &buffer[index + terminator.len()..])) + } else if buffer.len() >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES { + Err(Error::NoGarbageTerminator) + } else { + Err(Error::CiphertextTooSmall) + } } - /// Decrypt the next packet in the buffer while - /// book keeping relevant lengths and indexes. This allows - /// the buffer to be re-processed without throwing off - /// the state of the ciphers. - /// - /// # Returns - /// - /// True if the decrypted packet is the version packet. fn decrypt_packet( &mut self, ciphertext: &[u8], - packet_buffer: &mut [u8], - garbage: Option<&[u8]>, - ) -> Result { - let cipher_session = self - .cipher_session - .as_mut() - .ok_or(Error::HandshakeOutOfOrder)?; - - if self.current_packet_length_bytes.is_none() { - // Bounds check on the input buffer. - if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES { - return Err(Error::CiphertextTooSmall); - } - let packet_length = cipher_session.inbound().decrypt_packet_len( - ciphertext[self.current_buffer_index..self.current_buffer_index + NUM_LENGTH_BYTES] - .try_into() - .expect("Buffer slice must be exactly 3 bytes long"), - ); - // Hang on to decrypted length incase follow up steps fail - // and another authentication attempt is required. Avoids - // throwing off the cipher state. - self.current_packet_length_bytes = Some(packet_length); + output_buffer: &mut [u8], + aad: Option<&[u8]>, + ) -> Result<(PacketType, usize), Error> { + if ciphertext.len() < NUM_LENGTH_BYTES { + return Err(Error::CiphertextTooSmall); } - let packet_length = self - .current_packet_length_bytes - .ok_or(Error::HandshakeOutOfOrder)?; + let packet_len = self.state.cipher.inbound().decrypt_packet_len( + ciphertext[..NUM_LENGTH_BYTES] + .try_into() + .expect("Checked length above"), + ); - // Bounds check on input buffer. - if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES + packet_length { + if ciphertext.len() < NUM_LENGTH_BYTES + packet_len { return Err(Error::CiphertextTooSmall); } - let packet_type = cipher_session.inbound().decrypt( - &ciphertext[self.current_buffer_index + NUM_LENGTH_BYTES - ..self.current_buffer_index + NUM_LENGTH_BYTES + packet_length], - packet_buffer, - garbage, - )?; - // Mark current decryption point in the buffer. - self.current_buffer_index = self.current_buffer_index + NUM_LENGTH_BYTES + packet_length; - self.current_packet_length_bytes = None; - - Ok(matches!(packet_type, PacketType::Genuine)) - } + // Check output buffer is large enough. + let plaintext_len = InboundCipher::decryption_buffer_len(packet_len); + if output_buffer.len() < plaintext_len { + return Err(Error::BufferTooSmall { + required_bytes: plaintext_len, + }); + } - /// Complete the handshake and return the cipher session for further communication. - /// - /// # Error - /// - /// * `HandshakeOutOfOrder` - The handshake sequence is in a bad state and should be restarted. - pub fn finalize(self) -> Result { - let cipher_session = self.cipher_session.ok_or(Error::HandshakeOutOfOrder)?; - Ok(cipher_session) - } + let packet_type = self.state.cipher.inbound().decrypt( + &ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len], + &mut output_buffer[..plaintext_len], + aad, + )?; - /// Split off garbage in the given buffer on the remote garbage terminator. - /// - /// # Returns - /// - /// A `Result` containing the garbage and the remaining ciphertext not including the terminator. - /// - /// # Error - /// - /// * `CiphertextTooSmall` - Buffer did not contain a garbage terminator. - /// * `MaxGarbageLength` - Buffer did not contain the garbage terminator and contains too much garbage, should not be retried. - fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { - let garbage_term = self - .remote_garbage_terminator - .ok_or(Error::HandshakeOutOfOrder)?; - if let Some(index) = buffer - .windows(garbage_term.len()) - .position(|window| window == garbage_term) - { - Ok((&buffer[..index], &buffer[(index + garbage_term.len())..])) - } else if buffer.len() >= (MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES) { - Err(Error::NoGarbageTerminator) - } else { - // Terminator not found, the buffer needs more information. - Err(Error::CiphertextTooSmall) - } + Ok((packet_type, NUM_LENGTH_BYTES + packet_len)) } } #[cfg(all(test, feature = "std"))] mod tests { - use bitcoin::secp256k1::ellswift::{ElligatorSwift, ElligatorSwiftParty}; - use core::str::FromStr; - use hex::prelude::*; - use std::{string::ToString, vec}; + use std::vec; use super::*; #[test] - fn test_initial_message() { - let mut message = [0u8; 64]; - let handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut message).unwrap(); - let message = message.to_lower_hex_string(); - let es = handshake.point.elligator_swift.to_string(); - assert!(message.contains(&es)) - } + fn test_handshake() { + let initiator_garbage = vec![1u8, 2u8, 3u8]; + let responder_garbage = vec![4u8, 5u8]; + + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + + // Send initiator key + garbage. + let mut init_buffer = + vec![0u8; Handshake::::send_key_len(Some(&initiator_garbage))]; + let init_handshake = init_handshake + .send_key(Some(&initiator_garbage), &mut init_buffer) + .unwrap(); + assert_eq!(init_handshake.bytes_written(), 64 + initiator_garbage.len()); - #[test] - fn test_message_response() { - let mut message = [0u8; 64]; - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut message).unwrap(); - - let mut response_message = [0u8; 100]; - let mut response = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut response_message, - ) - .unwrap(); - - response - .complete_materials(message, &mut response_message[64..], None) + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key + garbage. + let mut resp_buffer = + vec![0u8; Handshake::::send_key_len(Some(&responder_garbage))]; + let resp_handshake = resp_handshake + .send_key(Some(&responder_garbage), &mut resp_buffer) .unwrap(); - } + assert_eq!(resp_handshake.bytes_written(), 64 + responder_garbage.len()); - #[test] - fn test_shared_secret() { - // Test that SessionKeyMaterial::from_ecdh produces expected garbage terminators - let alice = - SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") - .unwrap(); - let elliswift_alice = ElligatorSwift::from_str("ec0adff257bbfe500c188c80b4fdd640f6b45a482bbc15fc7cef5931deff0aa186f6eb9bba7b85dc4dcc28b28722de1e3d9108b985e2967045668f66098e475b").unwrap(); - let elliswift_bob = ElligatorSwift::from_str("a4a94dfce69b4a2a0a099313d10f9f7e7d649d60501c9e1d274c300e0d89aafaffffffffffffffffffffffffffffffffffffffffffffffffffffffff8faf88d5").unwrap(); - let session_keys = SessionKeyMaterial::from_ecdh( - elliswift_alice, - elliswift_bob, - alice, - ElligatorSwiftParty::A, - Network::Bitcoin, - ) - .unwrap(); - // Just verify the garbage terminators which are the only public fields we need - assert_eq!( - "faef555dfcdb936425d84aba524758f3", - session_keys - .initiator_garbage_terminator - .to_lower_hex_string() - ); - assert_eq!( - "02cb8ff24307a6e27de3b4e7ea3fa65b", - session_keys - .responder_garbage_terminator - .to_lower_hex_string() - ); + // Initiator receives responder's key. + let init_handshake = init_handshake + .receive_key(resp_buffer[..64].try_into().unwrap()) + .unwrap(); + + // Responder receives initiator's key. + let resp_handshake = resp_handshake + .receive_key(init_buffer[..64].try_into().unwrap()) + .unwrap(); + + // Create decoy packets for both sides. + let init_decoy1 = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let init_decoy2 = vec![0xCA, 0xFE, 0xBA, 0xBE, 0x00, 0x01]; + let init_decoys = vec![init_decoy1.as_slice(), init_decoy2.as_slice()]; + + let resp_decoy1 = vec![0xAB, 0xCD, 0xEF]; + let resp_decoys = vec![resp_decoy1.as_slice()]; + + // Initiator sends decoys and version. + let mut init_version_buffer = + vec![0u8; Handshake::::send_version_len(Some(&init_decoys))]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, Some(&init_decoys)) + .unwrap(); + + // Responder sends decoys and version. + let mut resp_version_buffer = + vec![0u8; Handshake::::send_version_len(Some(&resp_decoys))]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, Some(&resp_decoys)) + .unwrap(); + + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Initiator receives responder's garbage, decoys, and version. + let full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); + match init_handshake + .receive_version(&full_resp_message, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { bytes_consumed, .. } => { + assert_eq!(bytes_consumed, full_resp_message.len()); + } + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + } + + // Responder receives initiator's garbage, decoys, and version. + let full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); + match resp_handshake + .receive_version(&full_init_message, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { bytes_consumed, .. } => { + assert_eq!(bytes_consumed, full_init_message.len()); + } + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + } } #[test] fn test_handshake_garbage_length_check() { - let mut rng = rand::thread_rng(); - let curve = Secp256k1::new(); - let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES]; - - // Test with valid garbage length. + // Test with valid garbage length let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&valid_garbage), - &mut handshake_buffer, - &mut rng, - &curve, - ); + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES]; + let result = handshake.send_key(Some(&valid_garbage), &mut buffer); assert!(result.is_ok()); - // Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES. + // Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES let too_much_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES + 1]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&too_much_garbage), - &mut handshake_buffer, - &mut rng, - &curve, - ); + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(Some(&too_much_garbage), &mut buffer); assert!(matches!(result, Err(Error::TooMuchGarbage))); - // Test too small of buffer. + // Test too small of buffer let buffer_size = NUM_ELLIGATOR_SWIFT_BYTES + valid_garbage.len() - 1; let mut too_small_buffer = vec![0u8; buffer_size]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&valid_garbage), - &mut too_small_buffer, - &mut rng, - &curve, - ); - + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(Some(&valid_garbage), &mut too_small_buffer); assert!( matches!(result, Err(Error::BufferTooSmall { required_bytes }) if required_bytes == NUM_ELLIGATOR_SWIFT_BYTES + valid_garbage.len()), "Expected BufferTooSmall with correct size" ); - // Test with no garbage. - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - None, - &mut handshake_buffer, - &mut rng, - &curve, - ); + // Test with no garbage + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(None, &mut buffer); assert!(result.is_ok()); } + #[test] + fn test_handshake_receive_version_buffer() { + // Test the scenario where receive_version needs more data. + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + let mut init_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let init_handshake = init_handshake.send_key(None, &mut init_buffer).unwrap(); + + let mut resp_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let resp_handshake = resp_handshake.send_key(None, &mut resp_buffer).unwrap(); + + let init_handshake = init_handshake + .receive_key(resp_buffer[..64].try_into().unwrap()) + .unwrap(); + let resp_handshake = resp_handshake + .receive_key(init_buffer[..64].try_into().unwrap()) + .unwrap(); + + let mut init_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let _init_handshake = init_handshake + .send_version(&mut init_version_buffer, None) + .unwrap(); + + let mut resp_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, None) + .unwrap(); + + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Feed data in very small chunks to trigger NeedMoreData. + let partial_data_1 = &init_version_buffer[..1]; + let returned_handshake = match resp_handshake + .receive_version(partial_data_1, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::NeedMoreData(handshake) => handshake, + HandshakeAuthentication::Complete { .. } => { + panic!("Should have needed more data with 1 byte") + } + }; + + // Feed a bit more data - still probably not enough. + let partial_data_2 = &init_version_buffer[..5]; + let returned_handshake = match returned_handshake + .receive_version(partial_data_2, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::NeedMoreData(handshake) => handshake, + HandshakeAuthentication::Complete { .. } => { + panic!("Should have needed more data with 5 bytes") + } + }; + + // Now provide the complete data. + match returned_handshake + .receive_version(&init_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { bytes_consumed, .. } => { + assert_eq!(bytes_consumed, init_version_buffer.len()); + } + HandshakeAuthentication::NeedMoreData(_) => { + panic!("Should have completed with full data") + } + } + } + #[test] fn test_handshake_no_garbage_terminator() { - let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; - let mut rng = rand::thread_rng(); - let curve = Secp256k1::signing_only(); + // Create a handshake and bring it to the SentVersion state to test split_garbage + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; 64]; + let handshake = handshake.send_key(None, &mut buffer).unwrap(); - let mut handshake = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - None, - &mut handshake_buffer, - &mut rng, - &curve, - ) - .expect("Handshake creation should succeed"); + // Create a fake peer key to receive + let fake_peer_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let handshake = handshake.receive_key(fake_peer_key).unwrap(); - // Skipping material creation and just placing a mock terminator. - handshake.remote_garbage_terminator = Some([0xFF; NUM_GARBAGE_TERMINTOR_BYTES]); + // Send version to get to SentVersion state + let mut version_buffer = vec![0u8; 1024]; + let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); - // Test with a buffer that is too long. + // Test with a buffer that is too long (should fail to find terminator) let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; let result = handshake.split_garbage(&test_buffer); assert!(matches!(result, Err(Error::NoGarbageTerminator))); - // Test with a buffer that's just short of the required length. + // Test with a buffer that's just short of the required length let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; let result = handshake.split_garbage(&short_buffer); assert!(matches!(result, Err(Error::CiphertextTooSmall))); diff --git a/protocol/src/io.rs b/protocol/src/io.rs index 982f2b7..3498a98 100644 --- a/protocol/src/io.rs +++ b/protocol/src/io.rs @@ -20,8 +20,9 @@ use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::{ + handshake::{self, HandshakeAuthentication}, Error, Handshake, InboundCipher, OutboundCipher, PacketType, Role, NUM_ELLIGATOR_SWIFT_BYTES, - NUM_GARBAGE_TERMINTOR_BYTES, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, VERSION_CONTENT, + NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, }; /// A decrypted BIP324 payload with its packet type. @@ -141,6 +142,8 @@ pub struct AsyncProtocol { impl AsyncProtocol { /// New protocol session which completes the initial handshake and returns a handler. /// + /// This function is *not* cancellation safe. + /// /// # Arguments /// /// * `network` - Network which both parties are operating on. @@ -171,77 +174,68 @@ impl AsyncProtocol { R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - let garbage_len = match garbage { - Some(slice) => slice.len(), - None => 0, - }; - let mut ellswift_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + garbage_len]; - let mut handshake = Handshake::new(network, role, garbage, &mut ellswift_buffer)?; - - // Send initial key to remote. - writer.write_all(&ellswift_buffer).await?; + let handshake = Handshake::::new(network, role)?; + + // Send local public key and optional garbage. + let key_buffer_len = Handshake::::send_key_len(garbage); + let mut key_buffer = vec![0u8; key_buffer_len]; + let handshake = handshake.send_key(garbage, &mut key_buffer)?; + writer + .write_all(&key_buffer[..handshake.bytes_written()]) + .await?; writer.flush().await?; - // Read remote's initial key. - let mut remote_ellswift_buffer = [0u8; 64]; + // Read remote's public key. + let mut remote_ellswift_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; reader.read_exact(&mut remote_ellswift_buffer).await?; - - let num_version_packet_bytes = OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); - let num_decoy_packets_bytes: usize = match decoys { - Some(decoys) => decoys - .iter() - .map(|decoy| OutboundCipher::encryption_buffer_len(decoy.len())) - .sum(), - None => 0, - }; - - // Complete materials and send terminator to remote. - // Not exposing decoy packets yet. - let mut terminator_and_version_buffer = - vec![ - 0u8; - NUM_GARBAGE_TERMINTOR_BYTES + num_version_packet_bytes + num_decoy_packets_bytes - ]; - handshake.complete_materials( - remote_ellswift_buffer, - &mut terminator_and_version_buffer, - decoys, - )?; - writer.write_all(&terminator_and_version_buffer).await?; + let handshake = handshake.receive_key(remote_ellswift_buffer)?; + + // Send garbage terminator, decoys, and version. + let version_buffer_len = Handshake::::send_version_len(decoys); + let mut version_buffer = vec![0u8; version_buffer_len]; + let handshake = handshake.send_version(&mut version_buffer, decoys)?; + writer + .write_all(&version_buffer[..handshake.bytes_written()]) + .await?; writer.flush().await?; // Receive and authenticate remote garbage and version. - // Keep pulling bytes from the buffer until the garbage is flushed. let mut remote_garbage_and_version_buffer = Vec::with_capacity(NUM_INITIAL_HANDSHAKE_BUFFER_BYTES); let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + let mut handshake = handshake; loop { let mut temp_buffer = [0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; match reader.read(&mut temp_buffer).await { - // No data available right now, continue. - Ok(0) => { - continue; - } + Ok(0) => continue, Ok(bytes_read) => { remote_garbage_and_version_buffer.extend_from_slice(&temp_buffer[..bytes_read]); - match handshake.authenticate_garbage_and_version( - &remote_garbage_and_version_buffer, - &mut packet_buffer, - ) { - Ok(()) => break, - // Not enough data, continue reading. - Err(Error::CiphertextTooSmall) => continue, + handshake = match handshake + .receive_version(&remote_garbage_and_version_buffer, &mut packet_buffer) + { + Ok(HandshakeAuthentication::Complete { cipher, .. }) => { + let (inbound_cipher, outbound_cipher) = cipher.into_split(); + return Ok(Self { + reader: AsyncProtocolReader { + inbound_cipher, + state: DecryptState::init_reading_length(), + }, + writer: AsyncProtocolWriter { outbound_cipher }, + }); + } + Ok(HandshakeAuthentication::NeedMoreData(handshake)) => handshake, Err(Error::BufferTooSmall { required_bytes }) => { packet_buffer.resize(required_bytes, 0); - continue; + return Err(ProtocolError::Internal(Error::BufferTooSmall { + required_bytes, + })); } Err(e) => return Err(ProtocolError::Internal(e)), - } + }; } Err(e) => match e.kind() { - // No data available or interrupted, retry. std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => { continue; } @@ -249,17 +243,6 @@ impl AsyncProtocol { }, } } - - let cipher_session = handshake.finalize()?; - let (inbound_cipher, outbound_cipher) = cipher_session.into_split(); - - Ok(Self { - reader: AsyncProtocolReader { - inbound_cipher, - state: DecryptState::init_reading_length(), - }, - writer: AsyncProtocolWriter { outbound_cipher }, - }) } /// Read reference for packet reading operations. @@ -365,7 +348,8 @@ impl AsyncProtocolReader { self.inbound_cipher .decrypt(packet_bytes, &mut plaintext_buffer, None)?; self.state = DecryptState::init_reading_length(); - return Ok(Payload::new(plaintext_buffer, packet_type)); + // Skip the header byte (first byte) which contains the packet type + return Ok(Payload::new(plaintext_buffer[1..].to_vec(), packet_type)); } } } diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 70910d7..91679fb 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -39,7 +39,10 @@ use bitcoin_hashes::{hkdf, sha256, Hkdf}; pub use bitcoin::Network; -pub use handshake::{Handshake, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES}; +pub use handshake::{ + Handshake, HandshakeAuthentication, Initialized, ReceivedKey, SentKey, SentVersion, + NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, +}; // Re-exports from io module (async I/O types for backwards compatibility) #[cfg(any(feature = "futures", feature = "tokio"))] pub use io::{ @@ -82,8 +85,6 @@ pub enum Error { /// The remote sent the maximum amount of garbage bytes without /// a garbage terminator in the handshake. NoGarbageTerminator, - /// A handshake step was not completed in the proper order. - HandshakeOutOfOrder, /// The remote peer is communicating on the V1 protocol. V1Protocol, /// Not able to generate secret material. @@ -108,7 +109,6 @@ impl fmt::Display for Error { Error::NoGarbageTerminator => { write!(f, "More than 4095 bytes of garbage recieved in the handshake before a terminator was sent.") } - Error::HandshakeOutOfOrder => write!(f, "Handshake flow out of sequence."), Error::SecretGeneration(e) => write!(f, "Cannot generate secrets: {e:?}."), Error::Decryption(e) => write!(f, "Decrytion error: {e:?}."), Error::V1Protocol => write!(f, "The remote peer is communicating on the V1 protocol."), @@ -127,7 +127,6 @@ impl std::error::Error for Error { Error::CiphertextTooSmall => None, Error::BufferTooSmall { .. } => None, Error::NoGarbageTerminator => None, - Error::HandshakeOutOfOrder => None, Error::V1Protocol => None, Error::SecretGeneration(e) => Some(e), Error::Decryption(e) => Some(e), @@ -472,7 +471,7 @@ pub struct CipherSession { } impl CipherSession { - fn new(materials: SessionKeyMaterial, role: Role) -> Self { + pub(crate) fn new(materials: SessionKeyMaterial, role: Role) -> Self { match role { Role::Initiator => { let initiator_length_cipher = FSChaCha20::new(materials.initiator_length_key); @@ -726,107 +725,90 @@ mod tests { #[test] fn test_handshake_with_garbage_and_decoys() { + use crate::handshake::{HandshakeAuthentication, Initialized}; + // Define the garbage and decoys that the initiator is sending to the responder. let initiator_garbage = vec![1u8, 2u8, 3u8]; let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]]; - let num_initiator_decoys_bytes = initiator_decoys - .iter() - .map(|slice| slice.len()) - .sum::() - + NUM_PACKET_OVERHEAD_BYTES * initiator_decoys.len(); - let num_initiator_version_bytes = VERSION_CONTENT.len() + NUM_PACKET_OVERHEAD_BYTES; - // Buffer for initiator to write to and responder to read from. - let mut initiator_buffer = vec![ - 0u8; - NUM_ELLIGATOR_SWIFT_BYTES - + initiator_garbage.len() - + NUM_GARBAGE_TERMINTOR_BYTES - + num_initiator_decoys_bytes - + num_initiator_version_bytes - ]; // Define the garbage and decoys that the responder is sending to the initiator. let responder_garbage = vec![4u8, 5u8]; let responder_decoys: Vec<&[u8]> = vec![&[10u8, 11u8], &[12u8], &[13u8, 14u8, 15u8]]; - let num_responder_decoys_bytes = responder_decoys - .iter() - .map(|slice| slice.len()) - .sum::() - + NUM_PACKET_OVERHEAD_BYTES * responder_decoys.len(); - let num_responder_version_bytes = VERSION_CONTENT.len() + NUM_PACKET_OVERHEAD_BYTES; - // Buffer for responder to write to and initiator to read from. - let mut responder_buffer = vec![ - 0u8; - NUM_ELLIGATOR_SWIFT_BYTES - + responder_garbage.len() - + NUM_GARBAGE_TERMINTOR_BYTES - + num_responder_decoys_bytes - + num_responder_version_bytes - ]; - - // The initiator's handshake writes its 64 byte elligator swift key and garbage to their buffer to send to the responder. - let mut initiator_handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - Some(&initiator_garbage), - &mut initiator_buffer[..NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()], - ) - .unwrap(); - // The responder also writes its 64 byte elligator swift key and garbage to their buffer to send to the initiator. - let mut responder_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - Some(&responder_garbage), - &mut responder_buffer[..NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()], - ) - .unwrap(); + // Create initiator handshake + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - // The responder has received the initiator's initial material so can complete the secrets. - // With the secrets calculated, the responder can send along the garbage terminator, decoys, and version packet. - responder_handshake - .complete_materials( - initiator_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] - .try_into() - .unwrap(), - &mut responder_buffer[NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()..], - Some(&responder_decoys), - ) + // Send initiator key + garbage + let mut init_key_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()]; + let init_handshake = init_handshake + .send_key(Some(&initiator_garbage), &mut init_key_buffer) .unwrap(); - // Once the initiator receives the responder's response it can also complete the secrets. - // The initiator then needs to send along their recently calculated garbage terminator, decoys, and the version packet. - initiator_handshake - .complete_materials( - responder_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] - .try_into() - .unwrap(), - &mut initiator_buffer[NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()..], - Some(&initiator_decoys), - ) + // Create responder handshake + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key + garbage + let mut resp_key_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()]; + let resp_handshake = resp_handshake + .send_key(Some(&responder_garbage), &mut resp_key_buffer) .unwrap(); - // The initiator verifies the second half of the responders message which - // includes the garbage, garbage terminator, decoys, and version packet. - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - initiator_handshake - .authenticate_garbage_and_version( - &responder_buffer[NUM_ELLIGATOR_SWIFT_BYTES..], - &mut packet_buffer, - ) + // Initiator receives responder's key + let init_handshake = init_handshake + .receive_key(resp_key_buffer[..64].try_into().unwrap()) .unwrap(); - // The responder verifies the second message from the initiator which - // includes the garbage, garbage terminator, decoys, and version packet. - responder_handshake - .authenticate_garbage_and_version( - &initiator_buffer[NUM_ELLIGATOR_SWIFT_BYTES..], - &mut packet_buffer, - ) + // Responder receives initiator's key + let resp_handshake = resp_handshake + .receive_key(init_key_buffer[..64].try_into().unwrap()) .unwrap(); - let mut alice = initiator_handshake.finalize().unwrap(); - let mut bob = responder_handshake.finalize().unwrap(); + // Initiator sends version + let mut init_version_buffer = vec![0u8; 1024]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, Some(&initiator_decoys)) + .unwrap(); + let init_version_len = init_handshake.bytes_written(); + + // Responder sends version + let mut resp_version_buffer = vec![0u8; 1024]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, Some(&responder_decoys)) + .unwrap(); + let resp_version_len = resp_handshake.bytes_written(); + + // Complete handshakes + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Initiator receives responder's version + let full_resp_message = [ + &responder_garbage[..], + &resp_version_buffer[..resp_version_len], + ] + .concat(); + let mut alice = match init_handshake + .receive_version(&full_resp_message, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; + + // Responder receives initiator's version + let full_init_message = [ + &initiator_garbage[..], + &init_version_buffer[..init_version_len], + ] + .concat(); + let mut bob = match resp_handshake + .receive_version(&full_init_message, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; let message = b"Hello world".to_vec(); let packet_len = OutboundCipher::encryption_buffer_len(message.len()); @@ -887,42 +869,67 @@ mod tests { #[test] fn test_partial_decodings() { + use crate::handshake::{HandshakeAuthentication, Initialized}; let mut rng = rand::thread_rng(); - let mut init_message = vec![0u8; 64]; - let mut init_handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut init_message).unwrap(); + // Create initiator handshake + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - let mut resp_message = vec![0u8; 100]; - let mut resp_handshake = - Handshake::new(Network::Bitcoin, Role::Responder, None, &mut resp_message).unwrap(); + // Send initiator key (no garbage) + let mut init_key_buffer = vec![0u8; 64]; + let init_handshake = init_handshake.send_key(None, &mut init_key_buffer).unwrap(); - resp_handshake - .complete_materials( - init_message.try_into().unwrap(), - &mut resp_message[64..], - None, - ) + // Create responder handshake + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key (no garbage) + let mut resp_key_buffer = vec![0u8; 64]; + let resp_handshake = resp_handshake.send_key(None, &mut resp_key_buffer).unwrap(); + + // Initiator receives responder's key + let init_handshake = init_handshake + .receive_key(resp_key_buffer.try_into().unwrap()) .unwrap(); - let mut init_finalize_message = vec![0u8; 36]; - init_handshake - .complete_materials( - resp_message[0..64].try_into().unwrap(), - &mut init_finalize_message, - None, - ) + + // Responder receives initiator's key + let resp_handshake = resp_handshake + .receive_key(init_key_buffer.try_into().unwrap()) .unwrap(); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - init_handshake - .authenticate_garbage_and_version(&resp_message[64..], &mut packet_buffer) + // Initiator sends version + let mut init_version_buffer = vec![0u8; 36]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, None) .unwrap(); - resp_handshake - .authenticate_garbage_and_version(&init_finalize_message, &mut packet_buffer) + + // Responder sends version + let mut resp_version_buffer = vec![0u8; 36]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, None) .unwrap(); - let mut alice = init_handshake.finalize().unwrap(); - let mut bob = resp_handshake.finalize().unwrap(); + // Complete handshakes + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Initiator receives responder's version + let mut alice = match init_handshake + .receive_version(&resp_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; + + // Responder receives initiator's version + let mut bob = match resp_handshake + .receive_version(&init_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; let mut message_to_bob = Vec::new(); let message = gen_garbage(420, &mut rng); diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 86f9352..a1e5793 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -5,43 +5,68 @@ const PORT: u16 = 18444; #[test] #[cfg(feature = "std")] fn hello_world_happy_path() { - use bip324::{Handshake, PacketType, Role}; + use bip324::{ + Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, Role, + NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, + }; use bitcoin::Network; - let mut init_message = vec![0u8; 64]; - let mut init_handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut init_message).unwrap(); + // Create initiator handshake + let init_handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - let mut resp_message = vec![0u8; 100]; - let mut resp_handshake = - Handshake::new(Network::Bitcoin, Role::Responder, None, &mut resp_message).unwrap(); + // Send initiator key + let mut init_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let init_handshake = init_handshake.send_key(None, &mut init_key_buffer).unwrap(); - resp_handshake - .complete_materials( - init_message.try_into().unwrap(), - &mut resp_message[64..], - None, - ) + // Create responder handshake + let resp_handshake = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key + let mut resp_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let resp_handshake = resp_handshake.send_key(None, &mut resp_key_buffer).unwrap(); + + // Initiator receives responder's key + let init_handshake = init_handshake + .receive_key(resp_key_buffer[..64].try_into().unwrap()) .unwrap(); - let mut init_finalize_message = vec![0u8; 36]; - init_handshake - .complete_materials( - resp_message[0..64].try_into().unwrap(), - &mut init_finalize_message, - None, - ) + + // Responder receives initiator's key + let resp_handshake = resp_handshake + .receive_key(init_key_buffer[..64].try_into().unwrap()) .unwrap(); - let mut packet_buffer = vec![0u8; 4096]; - init_handshake - .authenticate_garbage_and_version(&resp_message[64..], &mut packet_buffer) + // Initiator sends version + let mut init_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, None) .unwrap(); - resp_handshake - .authenticate_garbage_and_version(&init_finalize_message, &mut packet_buffer) + + // Responder sends version + let mut resp_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, None) .unwrap(); - let mut alice = init_handshake.finalize().unwrap(); - let mut bob = resp_handshake.finalize().unwrap(); + // Complete handshakes + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + // Initiator receives responder's version + let mut alice = match init_handshake + .receive_version(&resp_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; + + // Responder receives initiator's version + let mut bob = match resp_handshake + .receive_version(&init_version_buffer, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => cipher, + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; // Alice and Bob can freely exchange encrypted messages using the packet handler returned by each handshake. let message = b"Hello world".to_vec(); @@ -106,48 +131,62 @@ fn regtest_handshake() { use bip324::{ serde::{deserialize, serialize, NetworkMessage}, - Handshake, PacketType, + Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, + NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, }; use bitcoin::p2p::{message_network::VersionMessage, Address, ServiceFlags}; let bitcoind = regtest_process(TransportVersion::V2); let mut stream = TcpStream::connect(bitcoind.params.p2p_socket.unwrap()).unwrap(); - let mut public_key = [0u8; 64]; - let mut handshake = Handshake::new( - bip324::Network::Regtest, - bip324::Role::Initiator, - None, - &mut public_key, - ) - .unwrap(); + + // Initialize handshake + let handshake = + Handshake::::new(bip324::Network::Regtest, bip324::Role::Initiator).unwrap(); + + // Send our public key + let mut public_key = vec![0u8; Handshake::::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut public_key).unwrap(); println!("Writing public key to the remote node"); stream.write_all(&public_key).unwrap(); stream.flush().unwrap(); + + // Read remote public key let mut remote_public_key = [0u8; 64]; println!("Reading the remote node public key"); stream.read_exact(&mut remote_public_key).unwrap(); - let mut local_garbage_terminator_message = [0u8; 36]; + + // Process remote key + let handshake = handshake.receive_key(remote_public_key).unwrap(); + + // Send garbage terminator and version + let mut local_garbage_terminator_message = + vec![0u8; Handshake::::send_version_len(None)]; println!("Sending our garbage terminator"); - handshake - .complete_materials( - remote_public_key, - &mut local_garbage_terminator_message, - None, - ) + let handshake = handshake + .send_version(&mut local_garbage_terminator_message, None) .unwrap(); stream.write_all(&local_garbage_terminator_message).unwrap(); stream.flush().unwrap(); + + // Read and authenticate remote response let mut max_response = [0; 4096]; println!("Reading the response buffer"); let size = stream.read(&mut max_response).unwrap(); let response = &mut max_response[..size]; println!("Authenticating the handshake"); - let mut packet_buffer = vec![0u8; 4096]; - handshake - .authenticate_garbage_and_version(response, &mut packet_buffer) - .unwrap(); - println!("Finalizing the handshake"); - let cipher_session = handshake.finalize().unwrap(); + let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + + let cipher_session = match handshake + .receive_version(response, &mut packet_buffer) + .unwrap() + { + HandshakeAuthentication::Complete { cipher, .. } => { + println!("Finalizing the handshake"); + cipher + } + HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + }; + let (mut decrypter, mut encrypter) = cipher_session.into_split(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -198,18 +237,15 @@ fn regtest_handshake_v1_only() { net::TcpStream, }; - use bip324::Handshake; + use bip324::{Handshake, Initialized}; let bitcoind = regtest_process(TransportVersion::V1); let mut stream = TcpStream::connect(bitcoind.params.p2p_socket.unwrap()).unwrap(); - let mut public_key = [0u8; 64]; - let _ = Handshake::new( - bip324::Network::Regtest, - bip324::Role::Initiator, - None, - &mut public_key, - ) - .unwrap(); + + let handshake = + Handshake::::new(bip324::Network::Regtest, bip324::Role::Initiator).unwrap(); + let mut public_key = vec![0u8; Handshake::::send_key_len(None)]; + let _handshake = handshake.send_key(None, &mut public_key).unwrap(); println!("Writing public key to the remote node"); stream.write_all(&public_key).unwrap(); stream.flush().unwrap(); From 0ebf7731a28030fff2e2704721e85f9119472abc Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Thu, 26 Jun 2025 12:25:57 -0700 Subject: [PATCH 2/5] feat: add decrypt_in_place to outbound cipher The BIP-324 protocol is packet in nature, not stream based, and the encrypted packets are not the same size as plaintext. They get length bits added to the front and authentication bits tacked on the back. So encryption in place would be a huge pain. And generally decryption in place is weird too, because there are these bytes on either end which no longer make sense to the plaintext. However, there is a solid use case for *decryption* in place in the handshake where packet content is not really cared about. --- protocol/src/lib.rs | 126 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 111 insertions(+), 15 deletions(-) diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 91679fb..ae2dc8d 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -311,7 +311,7 @@ pub struct InboundCipher { } impl InboundCipher { - /// Decrypt the length of the next inbound packet. + /// Decrypt the length of the packet's payload. /// /// Note that this returns the length of the remaining packet data /// to be read from the stream (header + contents + tag), not just the contents. @@ -339,11 +339,50 @@ impl InboundCipher { packet_len - NUM_TAG_BYTES } + /// Decrypt an inbound packet in-place. + /// + /// # Arguments + /// + /// * `ciphertext` - A mutable buffer containing the packet from the peer excluding + /// the first 3 length bytes. It should contain the header, contents, and authentication tag. + /// This buffer will be modified in-place during decryption. + /// * `aad` - Optional associated authenticated data. + /// + /// # Returns + /// + /// A `Result` containing: + /// * `Ok((PacketType, &[u8]))`: A tuple of the packet type and a slice pointing to the + /// decrypted plaintext within the input buffer. The first byte of the slice is the + /// header byte containing protocol flags. + /// * `Err(Error)`: An error that occurred during decryption. + /// + /// # Errors + /// + /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet. + /// * Decryption errors for any failures such as a tag mismatch. + pub fn decrypt_in_place<'a>( + &mut self, + ciphertext: &'a mut [u8], + aad: Option<&[u8]>, + ) -> Result<(PacketType, &'a [u8]), Error> { + let auth = aad.unwrap_or_default(); + // Check minimum size of ciphertext. + if ciphertext.len() < NUM_TAG_BYTES { + return Err(Error::CiphertextTooSmall); + } + let (msg, tag) = ciphertext.split_at_mut(ciphertext.len() - NUM_TAG_BYTES); + + self.packet_cipher + .decrypt(auth, msg, tag.try_into().expect("16 byte tag"))?; + + Ok((PacketType::from_byte(&msg[0]), msg)) + } + /// Decrypt an inbound packet. /// /// # Arguments /// - /// * `packet_payload` - The packet from the peer excluding the first 3 length bytes. It should contain + /// * `ciphertext` - The packet from the peer excluding the first 3 length bytes. It should contain /// the header, contents, and authentication tag. /// * `plaintext_buffer` - Mutable buffer to write plaintext. Note that the first byte is the header byte /// containing protocol flags. @@ -362,16 +401,16 @@ impl InboundCipher { /// * Decryption errors for any failures such as a tag mismatch. pub fn decrypt( &mut self, - packet_payload: &[u8], + ciphertext: &[u8], plaintext_buffer: &mut [u8], aad: Option<&[u8]>, ) -> Result { let auth = aad.unwrap_or_default(); // Check minimum size of ciphertext. - if packet_payload.len() < NUM_TAG_BYTES { + if ciphertext.len() < NUM_TAG_BYTES { return Err(Error::CiphertextTooSmall); } - let (msg, tag) = packet_payload.split_at(packet_payload.len() - NUM_TAG_BYTES); + let (msg, tag) = ciphertext.split_at(ciphertext.len() - NUM_TAG_BYTES); // Check that the contents buffer is large enough. if plaintext_buffer.len() < msg.len() { return Err(Error::BufferTooSmall { @@ -407,7 +446,7 @@ impl OutboundCipher { /// # Arguments /// /// * `plaintext` - Plaintext contents to be encrypted. - /// * `packet_buffer` - Buffer to write packet bytes to which must have enough capacity + /// * `ciphertext_buffer` - Buffer to write packet bytes to which must have enough capacity /// as calculated by `encryption_buffer_len(plaintext.len())`. /// * `packet_type` - Is this a genuine packet or a decoy. /// * `aad` - Optional associated authenticated data. @@ -419,12 +458,12 @@ impl OutboundCipher { pub fn encrypt( &mut self, plaintext: &[u8], - packet_buffer: &mut [u8], + ciphertext_buffer: &mut [u8], packet_type: PacketType, aad: Option<&[u8]>, ) -> Result<(), Error> { // Validate buffer capacity. - if packet_buffer.len() < Self::encryption_buffer_len(plaintext.len()) { + if ciphertext_buffer.len() < Self::encryption_buffer_len(plaintext.len()) { return Err(Error::BufferTooSmall { required_bytes: Self::encryption_buffer_len(plaintext.len()), }); @@ -436,14 +475,15 @@ impl OutboundCipher { let plaintext_end_index = plaintext_start_index + plaintext_length; // Set header byte. - packet_buffer[header_index] = packet_type.to_byte(); - packet_buffer[plaintext_start_index..plaintext_end_index].copy_from_slice(plaintext); + ciphertext_buffer[header_index] = packet_type.to_byte(); + ciphertext_buffer[plaintext_start_index..plaintext_end_index].copy_from_slice(plaintext); // Encrypt header byte and plaintext in place and produce authentication tag. let auth = aad.unwrap_or_default(); - let tag = self - .packet_cipher - .encrypt(auth, &mut packet_buffer[header_index..plaintext_end_index]); + let tag = self.packet_cipher.encrypt( + auth, + &mut ciphertext_buffer[header_index..plaintext_end_index], + ); // Encrypt plaintext length. let mut content_len = [0u8; 3]; @@ -451,8 +491,8 @@ impl OutboundCipher { self.length_cipher.crypt(&mut content_len); // Copy over encrypted length and the tag to the final packet (plaintext already encrypted). - packet_buffer[0..NUM_LENGTH_BYTES].copy_from_slice(&content_len); - packet_buffer[plaintext_end_index..(plaintext_end_index + NUM_TAG_BYTES)] + ciphertext_buffer[0..NUM_LENGTH_BYTES].copy_from_slice(&content_len); + ciphertext_buffer[plaintext_end_index..(plaintext_end_index + NUM_TAG_BYTES)] .copy_from_slice(&tag); Ok(()) @@ -609,6 +649,62 @@ mod tests { assert_eq!(message, plaintext_buffer[1..].to_vec()); // Skip header byte } + #[test] + fn test_decrypt_in_place() { + let alice = + SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") + .unwrap(); + let elliswift_alice = ElligatorSwift::from_str("ec0adff257bbfe500c188c80b4fdd640f6b45a482bbc15fc7cef5931deff0aa186f6eb9bba7b85dc4dcc28b28722de1e3d9108b985e2967045668f66098e475b").unwrap(); + let elliswift_bob = ElligatorSwift::from_str("a4a94dfce69b4a2a0a099313d10f9f7e7d649d60501c9e1d274c300e0d89aafaffffffffffffffffffffffffffffffffffffffffffffffffffffffff8faf88d5").unwrap(); + let session_keys = SessionKeyMaterial::from_ecdh( + elliswift_alice, + elliswift_bob, + alice, + ElligatorSwiftParty::A, + Network::Bitcoin, + ) + .unwrap(); + let mut alice_cipher = CipherSession::new(session_keys.clone(), Role::Initiator); + let mut bob_cipher = CipherSession::new(session_keys, Role::Responder); + + // Test with a genuine packet + let message = b"Test in-place decryption".to_vec(); + let mut enc_packet = vec![0u8; OutboundCipher::encryption_buffer_len(message.len())]; + alice_cipher + .outbound() + .encrypt(&message, &mut enc_packet, PacketType::Genuine, None) + .unwrap(); + + // Decrypt in-place + let mut ciphertext = enc_packet[NUM_LENGTH_BYTES..].to_vec(); + let (packet_type, plaintext) = bob_cipher + .inbound() + .decrypt_in_place(&mut ciphertext, None) + .unwrap(); + + assert_eq!(PacketType::Genuine, packet_type); + assert_eq!(message, plaintext[1..].to_vec()); // Skip header byte + + // Test with a decoy packet and AAD + let message2 = b"Decoy with AAD".to_vec(); + let aad = b"additional authenticated data"; + let mut enc_packet2 = vec![0u8; OutboundCipher::encryption_buffer_len(message2.len())]; + bob_cipher + .outbound() + .encrypt(&message2, &mut enc_packet2, PacketType::Decoy, Some(aad)) + .unwrap(); + + // Decrypt in-place with AAD + let mut ciphertext2 = enc_packet2[NUM_LENGTH_BYTES..].to_vec(); + let (packet_type2, plaintext2) = alice_cipher + .inbound() + .decrypt_in_place(&mut ciphertext2, Some(aad)) + .unwrap(); + + assert_eq!(PacketType::Decoy, packet_type2); + assert_eq!(message2, plaintext2[1..].to_vec()); // Skip header byte + } + #[test] fn test_fuzz_packets() { let mut rng = rand::thread_rng(); From a3f343627658959ffe77b9277a0cdb3f239b37f7 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Thu, 26 Jun 2025 15:09:22 -0700 Subject: [PATCH 3/5] feat!: decrypt in place for handshake Update the receive_version to take a mutable input buffer instead of an immutable input and a mutable output buffer. During the handshake, callers don't care about decoy packets or version packet. And they really don't want to deal with the variable sized output buffer. --- protocol/benches/cipher_session.rs | 8 +- protocol/fuzz/fuzz_targets/handshake.rs | 10 +- protocol/src/handshake.rs | 182 ++++++++++++++++-------- protocol/src/io.rs | 43 +++--- protocol/src/lib.rs | 23 ++- protocol/tests/round_trips.rs | 19 +-- 6 files changed, 170 insertions(+), 115 deletions(-) diff --git a/protocol/benches/cipher_session.rs b/protocol/benches/cipher_session.rs index f9120ff..8283d99 100644 --- a/protocol/benches/cipher_session.rs +++ b/protocol/benches/cipher_session.rs @@ -6,7 +6,7 @@ extern crate test; use bip324::{ CipherSession, Handshake, HandshakeAuthentication, InboundCipher, Initialized, Network, - OutboundCipher, PacketType, ReceivedKey, Role, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, + OutboundCipher, PacketType, ReceivedKey, Role, }; use test::{black_box, Bencher}; @@ -45,11 +45,9 @@ fn create_cipher_session_pair() -> (CipherSession, CipherSession) { .send_version(&mut bob_version_buffer, None) .unwrap(); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Alice receives Bob's version. let alice = match alice_handshake - .receive_version(&bob_version_buffer, &mut packet_buffer) + .receive_version(&mut bob_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -58,7 +56,7 @@ fn create_cipher_session_pair() -> (CipherSession, CipherSession) { // Bob receives Alice's version. let bob = match bob_handshake - .receive_version(&alice_version_buffer, &mut packet_buffer) + .receive_version(&mut alice_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, diff --git a/protocol/fuzz/fuzz_targets/handshake.rs b/protocol/fuzz/fuzz_targets/handshake.rs index 09b9ee6..f9b61df 100644 --- a/protocol/fuzz/fuzz_targets/handshake.rs +++ b/protocol/fuzz/fuzz_targets/handshake.rs @@ -7,10 +7,7 @@ //! * The implementation should handle all inputs gracefully. #![no_main] -use bip324::{ - Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role, - NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, -}; +use bip324::{Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role}; use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { @@ -41,9 +38,8 @@ fuzz_target!(|data: &[u8]| { let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); // Try to receive and authenticate the fuzzed garbage and version data. - let garbage_and_version = Vec::from(&data[64..]); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - match handshake.receive_version(&garbage_and_version, &mut packet_buffer) { + let mut garbage_and_version = Vec::from(&data[64..]); + match handshake.receive_version(&mut garbage_and_version) { Ok(HandshakeAuthentication::Complete { .. }) => { // Handshake completed successfully. // This should only happen with some very lucky random bytes. diff --git a/protocol/src/handshake.rs b/protocol/src/handshake.rs index 8976784..5b3fb98 100644 --- a/protocol/src/handshake.rs +++ b/protocol/src/handshake.rs @@ -20,14 +20,17 @@ use bitcoin::{ use rand::Rng; use crate::{ - CipherSession, Error, InboundCipher, OutboundCipher, PacketType, Role, SessionKeyMaterial, + CipherSession, Error, OutboundCipher, PacketType, Role, SessionKeyMaterial, NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES, VERSION_CONTENT, }; -/// Initial buffer for decoy and version packets in the handshake. -/// The buffer may have to be expanded if a party is sending large -/// decoy packets. -pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096; +/// Initial buffer length hint to receive garbage, the garbage terminator, +/// and the version packet from the remote peer. This assumes no decoy packets +/// are sent (which is the default in Bitcoin Core) and no garbage bytes. +/// Calculated as: garbage_terminator (16 bytes) + encrypted_version_packet +/// where the version packet content is currently empty (0 bytes). +pub const NUM_INITIAL_BUFFER_BYTES_HINT: usize = + NUM_GARBAGE_TERMINTOR_BYTES + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); // Maximum number of garbage bytes before the terminator. const MAX_NUM_GARBAGE_BYTES: usize = 4095; @@ -60,6 +63,8 @@ pub struct SentVersion { cipher: CipherSession, remote_garbage_terminator: [u8; NUM_GARBAGE_TERMINTOR_BYTES], bytes_written: usize, + ciphertext_index: usize, + remote_garbage_authenticated: bool, } /// Success variants for receive_version. @@ -147,7 +152,22 @@ impl<'a> Handshake<'a, Initialized> { NUM_ELLIGATOR_SWIFT_BYTES + garbage.map(|g| g.len()).unwrap_or(0) } - /// Send local public key and optional garbage. + /// Send local public key and optional garbage to initiate the handshake. + /// + /// # Parameters + /// + /// * `garbage` - Optional garbage bytes to append after the public key. Limited to 4095 bytes. + /// * `output_buffer` - Buffer to write the key and garbage. Must have sufficient capacity + /// as calculated by `send_key_len()`. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to receive remote peer's key material. + /// + /// # Errors + /// + /// * `TooMuchGarbage` - Garbage exceeds 4095 byte limit. + /// * `BufferTooSmall` - Output buffer insufficient for key + garbage. pub fn send_key( mut self, garbage: Option<&'a [u8]>, @@ -200,7 +220,25 @@ impl<'a> Handshake<'a, SentKey> { self.state.bytes_written } - /// Process received key material. + /// Process the remote peer's public key and derive shared session secrets. + /// + /// This is the **second state transition** in the handshake process, moving from + /// `SentKey` to `ReceivedKey` state. The method performs ECDH key exchange using + /// the received remote public key and generates all cryptographic material needed + /// for the secure session. + /// + /// # Parameters + /// + /// * `their_key` - The remote peer's 64-byte ElligatorSwift encoded public key. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to send version packet with derived session keys. + /// + /// # Errors + /// + /// * `V1Protocol` - Remote peer is using the legacy V1 protocol. + /// * `SecretGeneration` - Failed to derive session keys from ECDH. pub fn receive_key( self, their_key: [u8; NUM_ELLIGATOR_SWIFT_BYTES], @@ -268,7 +306,27 @@ impl<'a> Handshake<'a, ReceivedKey> { len } - /// Send garbage terminator, optional decoys, and version packet. + /// Send garbage terminator, optional decoy packets, and version packet. + /// + /// This is the **third state transition** in the handshake process, moving from + /// `ReceivedKey` to `SentVersion` state. The method initiates encrypted communication + /// by sending the local garbage terminator followed by encrypted packets. + /// + /// # Parameters + /// + /// * `output_buffer` - Buffer to write terminator and encrypted packets. Must have + /// sufficient capacity as calculated by `send_version_len()`. + /// * `decoys` - Optional array of decoy packet contents to send before version packet + /// to help hide the shape of traffic. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to receive and authenticate remote peer's version. + /// + /// # Errors + /// + /// * `BufferTooSmall` - Output buffer insufficient for terminator + packets. + /// * `Decryption` - Cipher operation failed. pub fn send_version( self, output_buffer: &mut [u8], @@ -334,6 +392,8 @@ impl<'a> Handshake<'a, ReceivedKey> { cipher, remote_garbage_terminator, bytes_written, + ciphertext_index: 0, + remote_garbage_authenticated: false, }, }) } @@ -346,11 +406,30 @@ impl<'a> Handshake<'a, SentVersion> { self.state.bytes_written } - /// Authenticate garbage and version packet. + /// Authenticate remote peer's garbage, decoy packets, and version packet. + /// + /// This method is unique in the handshake process as it requires a **mutable** input buffer + /// to perform in-place decryption operations. The buffer contains everything after the 64 + /// byte public key received from the remote peer: optional garbage bytes, garbage terminator, + /// and encrypted packets (decoys and final version packet). + /// + /// The input buffer is mutable in the case because the caller generally doesn't care + /// about the decoy and version packets, and definitely doesn't want to deal with + /// allocating memory for them. + /// + /// # Parameters + /// + /// * `input_buffer` - **Mutable** buffer containing garbage + terminator + encrypted packets. + /// The buffer will be modified during in-place decryption operations. + /// + /// # Returns + /// + /// * `Complete { cipher, bytes_consumed }` - Handshake succeeded, secure session established. + /// * `NeedMoreData(handshake)` - Insufficient data, retry by extending the buffer. + /// ``` pub fn receive_version( mut self, - input_buffer: &[u8], - output_buffer: &mut [u8], + input_buffer: &mut [u8], ) -> Result, Error> { let (garbage, ciphertext) = match self.split_garbage(input_buffer) { Ok(split) => split, @@ -360,22 +439,25 @@ impl<'a> Handshake<'a, SentVersion> { Err(e) => return Err(e), }; - let mut aad = if garbage.is_empty() { + let mut aad = if garbage.is_empty() || self.state.remote_garbage_authenticated { None } else { Some(garbage) }; - let mut ciphertext_index = 0; - let mut found_version = false; // First packet authenticates remote garbage. // Continue through decoys until we find version packet. - while !found_version { - match self.decrypt_packet(&ciphertext[ciphertext_index..], output_buffer, aad) { + loop { + match self.decrypt_packet(&mut ciphertext[self.state.ciphertext_index..], aad) { Ok((packet_type, bytes_consumed)) => { - aad = None; - ciphertext_index += bytes_consumed; - found_version = matches!(packet_type, PacketType::Genuine); + if aad.is_some() { + aad = None; + self.state.remote_garbage_authenticated = true; + } + self.state.ciphertext_index += bytes_consumed; + if matches!(packet_type, PacketType::Genuine) { + break; + } } Err(Error::CiphertextTooSmall) => { return Ok(HandshakeAuthentication::NeedMoreData(self)) @@ -384,8 +466,9 @@ impl<'a> Handshake<'a, SentVersion> { } } - // Calculate total bytes consumed (garbage + terminator + packets) - let bytes_consumed = garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + ciphertext_index; + // Calculate total bytes consumed. + let bytes_consumed = + garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + self.state.ciphertext_index; Ok(HandshakeAuthentication::Complete { cipher: self.state.cipher, @@ -394,14 +477,16 @@ impl<'a> Handshake<'a, SentVersion> { } /// Split buffer on garbage terminator. - fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { + fn split_garbage<'b>(&self, buffer: &'b mut [u8]) -> Result<(&'b [u8], &'b mut [u8]), Error> { let terminator = &self.state.remote_garbage_terminator; if let Some(index) = buffer .windows(terminator.len()) .position(|window| window == terminator) { - Ok((&buffer[..index], &buffer[index + terminator.len()..])) + let (garbage, rest) = buffer.split_at_mut(index); + let ciphertext = &mut rest[terminator.len()..]; + Ok((garbage, ciphertext)) } else if buffer.len() >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES { Err(Error::NoGarbageTerminator) } else { @@ -409,10 +494,10 @@ impl<'a> Handshake<'a, SentVersion> { } } + /// Decrypts in place, returns the packet type and number of bytes consumed from the ciphertext. fn decrypt_packet( &mut self, - ciphertext: &[u8], - output_buffer: &mut [u8], + ciphertext: &mut [u8], aad: Option<&[u8]>, ) -> Result<(PacketType, usize), Error> { if ciphertext.len() < NUM_LENGTH_BYTES { @@ -429,17 +514,8 @@ impl<'a> Handshake<'a, SentVersion> { return Err(Error::CiphertextTooSmall); } - // Check output buffer is large enough. - let plaintext_len = InboundCipher::decryption_buffer_len(packet_len); - if output_buffer.len() < plaintext_len { - return Err(Error::BufferTooSmall { - required_bytes: plaintext_len, - }); - } - - let packet_type = self.state.cipher.inbound().decrypt( - &ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len], - &mut output_buffer[..plaintext_len], + let (packet_type, _) = self.state.cipher.inbound().decrypt_in_place( + &mut ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len], aad, )?; @@ -512,12 +588,10 @@ mod tests { .send_version(&mut resp_version_buffer, Some(&resp_decoys)) .unwrap(); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Initiator receives responder's garbage, decoys, and version. - let full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); + let mut full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); match init_handshake - .receive_version(&full_resp_message, &mut packet_buffer) + .receive_version(&mut full_resp_message) .unwrap() { HandshakeAuthentication::Complete { bytes_consumed, .. } => { @@ -527,9 +601,9 @@ mod tests { } // Responder receives initiator's garbage, decoys, and version. - let full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); + let mut full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); match resp_handshake - .receive_version(&full_init_message, &mut packet_buffer) + .receive_version(&mut full_init_message) .unwrap() { HandshakeAuthentication::Complete { bytes_consumed, .. } => { @@ -601,13 +675,10 @@ mod tests { .send_version(&mut resp_version_buffer, None) .unwrap(); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Feed data in very small chunks to trigger NeedMoreData. let partial_data_1 = &init_version_buffer[..1]; - let returned_handshake = match resp_handshake - .receive_version(partial_data_1, &mut packet_buffer) - .unwrap() + let mut partial_data_1 = partial_data_1.to_vec(); + let returned_handshake = match resp_handshake.receive_version(&mut partial_data_1).unwrap() { HandshakeAuthentication::NeedMoreData(handshake) => handshake, HandshakeAuthentication::Complete { .. } => { @@ -617,8 +688,9 @@ mod tests { // Feed a bit more data - still probably not enough. let partial_data_2 = &init_version_buffer[..5]; + let mut partial_data_2 = partial_data_2.to_vec(); let returned_handshake = match returned_handshake - .receive_version(partial_data_2, &mut packet_buffer) + .receive_version(&mut partial_data_2) .unwrap() { HandshakeAuthentication::NeedMoreData(handshake) => handshake, @@ -628,10 +700,8 @@ mod tests { }; // Now provide the complete data. - match returned_handshake - .receive_version(&init_version_buffer, &mut packet_buffer) - .unwrap() - { + let mut full_data = init_version_buffer.clone(); + match returned_handshake.receive_version(&mut full_data).unwrap() { HandshakeAuthentication::Complete { bytes_consumed, .. } => { assert_eq!(bytes_consumed, init_version_buffer.len()); } @@ -657,13 +727,13 @@ mod tests { let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); // Test with a buffer that is too long (should fail to find terminator) - let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; - let result = handshake.split_garbage(&test_buffer); + let mut test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; + let result = handshake.split_garbage(&mut test_buffer); assert!(matches!(result, Err(Error::NoGarbageTerminator))); // Test with a buffer that's just short of the required length - let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; - let result = handshake.split_garbage(&short_buffer); + let mut short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; + let result = handshake.split_garbage(&mut short_buffer); assert!(matches!(result, Err(Error::CiphertextTooSmall))); } } diff --git a/protocol/src/io.rs b/protocol/src/io.rs index 3498a98..bafc90f 100644 --- a/protocol/src/io.rs +++ b/protocol/src/io.rs @@ -22,7 +22,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::{ handshake::{self, HandshakeAuthentication}, Error, Handshake, InboundCipher, OutboundCipher, PacketType, Role, NUM_ELLIGATOR_SWIFT_BYTES, - NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, + NUM_INITIAL_BUFFER_BYTES_HINT, }; /// A decrypted BIP324 payload with its packet type. @@ -200,21 +200,31 @@ impl AsyncProtocol { writer.flush().await?; // Receive and authenticate remote garbage and version. - let mut remote_garbage_and_version_buffer = - Vec::with_capacity(NUM_INITIAL_HANDSHAKE_BUFFER_BYTES); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; + let mut remote_garbage_and_version_buffer = vec![0u8; NUM_INITIAL_BUFFER_BYTES_HINT]; + let mut bytes_read_so_far = 0; let mut handshake = handshake; loop { - let mut temp_buffer = [0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - match reader.read(&mut temp_buffer).await { - Ok(0) => continue, + match reader + .read(&mut remote_garbage_and_version_buffer[bytes_read_so_far..]) + .await + { + Ok(0) => { + // EOF - remote closed connection. + return Err(ProtocolError::Io( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Remote peer closed connection during handshake", + ), + ProtocolFailureSuggestion::RetryV1, + )); + } Ok(bytes_read) => { - remote_garbage_and_version_buffer.extend_from_slice(&temp_buffer[..bytes_read]); + bytes_read_so_far += bytes_read; - handshake = match handshake - .receive_version(&remote_garbage_and_version_buffer, &mut packet_buffer) - { + handshake = match handshake.receive_version( + &mut remote_garbage_and_version_buffer[..bytes_read_so_far], + ) { Ok(HandshakeAuthentication::Complete { cipher, .. }) => { let (inbound_cipher, outbound_cipher) = cipher.into_split(); return Ok(Self { @@ -225,12 +235,11 @@ impl AsyncProtocol { writer: AsyncProtocolWriter { outbound_cipher }, }); } - Ok(HandshakeAuthentication::NeedMoreData(handshake)) => handshake, - Err(Error::BufferTooSmall { required_bytes }) => { - packet_buffer.resize(required_bytes, 0); - return Err(ProtocolError::Internal(Error::BufferTooSmall { - required_bytes, - })); + Ok(HandshakeAuthentication::NeedMoreData(handshake)) => { + // Need more data - extend buffer for next read. + remote_garbage_and_version_buffer + .resize(remote_garbage_and_version_buffer.len() + 1024, 0); + handshake } Err(e) => return Err(ProtocolError::Internal(e)), }; diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index ae2dc8d..645b548 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -41,7 +41,7 @@ pub use bitcoin::Network; pub use handshake::{ Handshake, HandshakeAuthentication, Initialized, ReceivedKey, SentKey, SentVersion, - NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, + NUM_INITIAL_BUFFER_BYTES_HINT, }; // Re-exports from io module (async I/O types for backwards compatibility) #[cfg(any(feature = "futures", feature = "tokio"))] @@ -437,7 +437,7 @@ pub struct OutboundCipher { impl OutboundCipher { /// Calculate the required encryption buffer length for given plaintext length. - pub fn encryption_buffer_len(plaintext_len: usize) -> usize { + pub const fn encryption_buffer_len(plaintext_len: usize) -> usize { plaintext_len + NUM_PACKET_OVERHEAD_BYTES } @@ -577,7 +577,6 @@ impl CipherSession { #[cfg(all(test, feature = "std"))] mod tests { - use crate::handshake::NUM_INITIAL_HANDSHAKE_BUFFER_BYTES; use super::*; use bitcoin::secp256k1::ellswift::{ElligatorSwift, ElligatorSwiftParty}; @@ -875,17 +874,14 @@ mod tests { .unwrap(); let resp_version_len = resp_handshake.bytes_written(); - // Complete handshakes - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Initiator receives responder's version - let full_resp_message = [ + let mut full_resp_message = [ &responder_garbage[..], &resp_version_buffer[..resp_version_len], ] .concat(); let mut alice = match init_handshake - .receive_version(&full_resp_message, &mut packet_buffer) + .receive_version(&mut full_resp_message) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -893,13 +889,13 @@ mod tests { }; // Responder receives initiator's version - let full_init_message = [ + let mut full_init_message = [ &initiator_garbage[..], &init_version_buffer[..init_version_len], ] .concat(); let mut bob = match resp_handshake - .receive_version(&full_init_message, &mut packet_buffer) + .receive_version(&mut full_init_message) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -1006,12 +1002,9 @@ mod tests { .send_version(&mut resp_version_buffer, None) .unwrap(); - // Complete handshakes - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Initiator receives responder's version let mut alice = match init_handshake - .receive_version(&resp_version_buffer, &mut packet_buffer) + .receive_version(&mut resp_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -1020,7 +1013,7 @@ mod tests { // Responder receives initiator's version let mut bob = match resp_handshake - .receive_version(&init_version_buffer, &mut packet_buffer) + .receive_version(&mut init_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index a1e5793..36340ff 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -5,10 +5,7 @@ const PORT: u16 = 18444; #[test] #[cfg(feature = "std")] fn hello_world_happy_path() { - use bip324::{ - Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, Role, - NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, - }; + use bip324::{Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, Role}; use bitcoin::Network; // Create initiator handshake @@ -47,12 +44,9 @@ fn hello_world_happy_path() { .send_version(&mut resp_version_buffer, None) .unwrap(); - // Complete handshakes - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - // Initiator receives responder's version let mut alice = match init_handshake - .receive_version(&resp_version_buffer, &mut packet_buffer) + .receive_version(&mut resp_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -61,7 +55,7 @@ fn hello_world_happy_path() { // Responder receives initiator's version let mut bob = match resp_handshake - .receive_version(&init_version_buffer, &mut packet_buffer) + .receive_version(&mut init_version_buffer) .unwrap() { HandshakeAuthentication::Complete { cipher, .. } => cipher, @@ -132,7 +126,6 @@ fn regtest_handshake() { use bip324::{ serde::{deserialize, serialize, NetworkMessage}, Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, - NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, }; use bitcoin::p2p::{message_network::VersionMessage, Address, ServiceFlags}; let bitcoind = regtest_process(TransportVersion::V2); @@ -174,12 +167,8 @@ fn regtest_handshake() { let size = stream.read(&mut max_response).unwrap(); let response = &mut max_response[..size]; println!("Authenticating the handshake"); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - let cipher_session = match handshake - .receive_version(response, &mut packet_buffer) - .unwrap() - { + let cipher_session = match handshake.receive_version(response).unwrap() { HandshakeAuthentication::Complete { cipher, .. } => { println!("Finalizing the handshake"); cipher From cc9251696a0232ba5f7e44f87220cf13f08b4eb1 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Fri, 27 Jun 2025 10:37:57 -0700 Subject: [PATCH 4/5] feat: minimize lifetime of local garbage --- protocol/src/handshake.rs | 61 +++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/protocol/src/handshake.rs b/protocol/src/handshake.rs index 5b3fb98..20634ed 100644 --- a/protocol/src/handshake.rs +++ b/protocol/src/handshake.rs @@ -47,15 +47,17 @@ pub struct Initialized { } /// **Second state** after sending the local public key. -pub struct SentKey { +pub struct SentKey<'a> { point: EcdhPoint, bytes_written: usize, + local_garbage: Option<&'a [u8]>, } /// **Third state** after receiving the remote's public key and /// generating the shared secret materials for the session. -pub struct ReceivedKey { +pub struct ReceivedKey<'a> { session_keys: SessionKeyMaterial, + local_garbage: Option<&'a [u8]>, } /// **Fourth state** after sending the version packet. @@ -68,14 +70,14 @@ pub struct SentVersion { } /// Success variants for receive_version. -pub enum HandshakeAuthentication<'a> { +pub enum HandshakeAuthentication { /// Successfully completed. Complete { cipher: CipherSession, bytes_consumed: usize, }, /// Need more data - returns handshake for caller to retry with more ciphertext. - NeedMoreData(Handshake<'a, SentVersion>), + NeedMoreData(Handshake), } /// Handshake state-machine to establish the secret material in the communication channel. @@ -87,19 +89,17 @@ pub enum HandshakeAuthentication<'a> { /// 3. `ReceivedKey` - After receiving remote's public key. /// 4. `SentVersion` - After sending local garbage terminator and version packet. /// 5. Complete - After receiving and authenticating remote's garbage, garbage terminator, decoy packets, and version packet. -pub struct Handshake<'a, State> { +pub struct Handshake { /// Bitcoin network both peers are operating on. network: Network, /// Local role in the handshake, initiator or responder. role: Role, - /// Optional local garbage bytes to send along in handshake. - garbage: Option<&'a [u8]>, /// State-specific data. state: State, } // Methods available in all states -impl<'a, State> Handshake<'a, State> { +impl Handshake { /// Get the network this handshake is operating on. pub fn network(&self) -> Network { self.network @@ -112,7 +112,7 @@ impl<'a, State> Handshake<'a, State> { } // Initialized state implementation -impl<'a> Handshake<'a, Initialized> { +impl Handshake { /// Initialize a V2 transport handshake with a remote peer. #[cfg(feature = "std")] pub fn new(network: Network, role: Role) -> Result { @@ -142,7 +142,6 @@ impl<'a> Handshake<'a, Initialized> { Ok(Handshake { network, role, - garbage: None, state: Initialized { point }, }) } @@ -168,11 +167,11 @@ impl<'a> Handshake<'a, Initialized> { /// /// * `TooMuchGarbage` - Garbage exceeds 4095 byte limit. /// * `BufferTooSmall` - Output buffer insufficient for key + garbage. - pub fn send_key( - mut self, + pub fn send_key<'a>( + self, garbage: Option<&'a [u8]>, output_buffer: &mut [u8], - ) -> Result, Error> { + ) -> Result>, Error> { // Validate garbage length if let Some(g) = garbage { if g.len() > MAX_NUM_GARBAGE_BYTES { @@ -198,23 +197,20 @@ impl<'a> Handshake<'a, Initialized> { written += g.len(); } - // Store garbage for later use. - self.garbage = garbage; - Ok(Handshake { network: self.network, role: self.role, - garbage: self.garbage, state: SentKey { point: self.state.point, bytes_written: written, + local_garbage: garbage, }, }) } } // SentKey state implementation -impl<'a> Handshake<'a, SentKey> { +impl<'a> Handshake> { /// Get how many bytes were written by send_key(). pub fn bytes_written(&self) -> usize { self.state.bytes_written @@ -242,7 +238,7 @@ impl<'a> Handshake<'a, SentKey> { pub fn receive_key( self, their_key: [u8; NUM_ELLIGATOR_SWIFT_BYTES], - ) -> Result, Error> { + ) -> Result>, Error> { let their_ellswift = ElligatorSwift::from_array(their_key); // Check for V1 protocol magic bytes @@ -283,14 +279,16 @@ impl<'a> Handshake<'a, SentKey> { Ok(Handshake { network: self.network, role: self.role, - garbage: self.garbage, - state: ReceivedKey { session_keys }, + state: ReceivedKey { + session_keys, + local_garbage: self.state.local_garbage, + }, }) } } // ReceivedKey state implementation -impl<'a> Handshake<'a, ReceivedKey> { +impl<'a> Handshake> { /// Calculate how many bytes send_version() will write to buffer. pub fn send_version_len(decoys: Option<&[&[u8]]>) -> usize { let mut len = NUM_GARBAGE_TERMINTOR_BYTES @@ -331,7 +329,7 @@ impl<'a> Handshake<'a, ReceivedKey> { self, output_buffer: &mut [u8], decoys: Option<&[&[u8]]>, - ) -> Result, Error> { + ) -> Result, Error> { let required_len = Self::send_version_len(decoys); if output_buffer.len() < required_len { return Err(Error::BufferTooSmall { @@ -358,7 +356,7 @@ impl<'a> Handshake<'a, ReceivedKey> { let mut bytes_written = NUM_GARBAGE_TERMINTOR_BYTES; // Local garbage is authenticated in first packet no // matter if it is a decoy or genuine. - let mut aad = self.garbage; + let mut aad = self.state.local_garbage; if let Some(decoys) = decoys { for decoy in decoys { @@ -387,7 +385,6 @@ impl<'a> Handshake<'a, ReceivedKey> { Ok(Handshake { network: self.network, role: self.role, - garbage: self.garbage, state: SentVersion { cipher, remote_garbage_terminator, @@ -400,7 +397,7 @@ impl<'a> Handshake<'a, ReceivedKey> { } // SentVersion state implementation -impl<'a> Handshake<'a, SentVersion> { +impl Handshake { /// Get how many bytes were written by send_version(). pub fn bytes_written(&self) -> usize { self.state.bytes_written @@ -430,7 +427,7 @@ impl<'a> Handshake<'a, SentVersion> { pub fn receive_version( mut self, input_buffer: &mut [u8], - ) -> Result, Error> { + ) -> Result { let (garbage, ciphertext) = match self.split_garbage(input_buffer) { Ok(split) => split, Err(Error::CiphertextTooSmall) => { @@ -576,14 +573,14 @@ mod tests { // Initiator sends decoys and version. let mut init_version_buffer = - vec![0u8; Handshake::::send_version_len(Some(&init_decoys))]; + vec![0u8; Handshake::>::send_version_len(Some(&init_decoys))]; let init_handshake = init_handshake .send_version(&mut init_version_buffer, Some(&init_decoys)) .unwrap(); // Responder sends decoys and version. let mut resp_version_buffer = - vec![0u8; Handshake::::send_version_len(Some(&resp_decoys))]; + vec![0u8; Handshake::>::send_version_len(Some(&resp_decoys))]; let resp_handshake = resp_handshake .send_version(&mut resp_version_buffer, Some(&resp_decoys)) .unwrap(); @@ -665,12 +662,14 @@ mod tests { .receive_key(init_buffer[..64].try_into().unwrap()) .unwrap(); - let mut init_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let mut init_version_buffer = + vec![0u8; Handshake::>::send_version_len(None)]; let _init_handshake = init_handshake .send_version(&mut init_version_buffer, None) .unwrap(); - let mut resp_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let mut resp_version_buffer = + vec![0u8; Handshake::>::send_version_len(None)]; let resp_handshake = resp_handshake .send_version(&mut resp_version_buffer, None) .unwrap(); From a5a4e71c7cca170c5c2aee96edb6853ea5b1d505 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 2 Jul 2025 15:40:26 -0700 Subject: [PATCH 5/5] feat!: split the last handshake step into two The real tricky part of the handshake is reading the unknown number of garbage bytes from the remote peer. That was hard to tease apart when combined with the authenticating the garbage with the following decoy or version packet, so split into two steps. --- justfile | 4 +- protocol/benches/cipher_session.rs | 72 ++- protocol/fuzz/Cargo.toml | 18 +- protocol/fuzz/fuzz_targets/handshake.rs | 56 -- protocol/fuzz/fuzz_targets/receive_garbage.rs | 63 ++ protocol/fuzz/fuzz_targets/receive_key.rs | 43 ++ protocol/fuzz/fuzz_targets/receive_version.rs | 70 +++ protocol/src/handshake.rs | 538 ++++++++++++------ protocol/src/io.rs | 134 +++-- protocol/src/lib.rs | 232 +------- protocol/tests/round_trips.rs | 131 ++++- 11 files changed, 784 insertions(+), 577 deletions(-) delete mode 100644 protocol/fuzz/fuzz_targets/handshake.rs create mode 100644 protocol/fuzz/fuzz_targets/receive_garbage.rs create mode 100644 protocol/fuzz/fuzz_targets/receive_key.rs create mode 100644 protocol/fuzz/fuzz_targets/receive_version.rs diff --git a/justfile b/justfile index ae02f82..fe2c160 100644 --- a/justfile +++ b/justfile @@ -75,8 +75,8 @@ _default: bench: cargo +{{NIGHTLY_TOOLCHAIN}} bench --package bip324 --bench cipher_session -# Run fuzz test: handshake. -@fuzz target="handshake" time="60": +# Run fuzz test: receive_key, receive_garbage, receive_version. +@fuzz target="receive_garbage" time="60": cargo install cargo-fuzz@0.12.0 cd protocol && cargo +{{NIGHTLY_TOOLCHAIN}} fuzz run {{target}} -- -max_total_time={{time}} diff --git a/protocol/benches/cipher_session.rs b/protocol/benches/cipher_session.rs index 8283d99..f0f3be3 100644 --- a/protocol/benches/cipher_session.rs +++ b/protocol/benches/cipher_session.rs @@ -5,8 +5,8 @@ extern crate test; use bip324::{ - CipherSession, Handshake, HandshakeAuthentication, InboundCipher, Initialized, Network, - OutboundCipher, PacketType, ReceivedKey, Role, + CipherSession, GarbageResult, Handshake, InboundCipher, Initialized, Network, OutboundCipher, + PacketType, ReceivedKey, Role, VersionResult, NUM_LENGTH_BYTES, }; use test::{black_box, Bencher}; @@ -46,21 +46,53 @@ fn create_cipher_session_pair() -> (CipherSession, CipherSession) { .unwrap(); // Alice receives Bob's version. - let alice = match alice_handshake - .receive_version(&mut bob_version_buffer) + // First handle Bob's garbage terminator + let (mut alice_handshake, consumed) = match alice_handshake + .receive_garbage(&bob_version_buffer) .unwrap() { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process Bob's version packet + let remaining = &bob_version_buffer[consumed..]; + let packet_len = alice_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + let alice = match alice_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), }; // Bob receives Alice's version. - let bob = match bob_handshake - .receive_version(&mut alice_version_buffer) + // First handle Alice's garbage terminator + let (mut bob_handshake, consumed) = match bob_handshake + .receive_garbage(&alice_version_buffer) .unwrap() { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process Alice's version packet + let remaining = &alice_version_buffer[consumed..]; + let packet_len = bob_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + let bob = match bob_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), }; (alice, bob) @@ -85,16 +117,16 @@ fn bench_round_trip_small_packet(b: &mut Bencher) { ) .unwrap(); - // Decrypt the length from first 3 bytes (real-world step). - let packet_length = bob - .inbound() - .decrypt_packet_len(black_box(encrypted[0..3].try_into().unwrap())); + // Decrypt the length from first NUM_LENGTH_BYTES bytes (real-world step). + let packet_length = bob.inbound().decrypt_packet_len(black_box( + encrypted[0..NUM_LENGTH_BYTES].try_into().unwrap(), + )); // Decrypt the payload using the decrypted length. let mut decrypted = vec![0u8; InboundCipher::decryption_buffer_len(packet_length)]; bob.inbound() .decrypt( - black_box(&encrypted[3..3 + packet_length]), + black_box(&encrypted[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_length]), &mut decrypted, None, ) @@ -124,16 +156,16 @@ fn bench_round_trip_large_packet(b: &mut Bencher) { ) .unwrap(); - // Decrypt the length from first 3 bytes (real-world step). - let packet_length = bob - .inbound() - .decrypt_packet_len(black_box(encrypted[0..3].try_into().unwrap())); + // Decrypt the length from first NUM_LENGTH_BYTES bytes (real-world step). + let packet_length = bob.inbound().decrypt_packet_len(black_box( + encrypted[0..NUM_LENGTH_BYTES].try_into().unwrap(), + )); // Decrypt the payload using the decrypted length. let mut decrypted = vec![0u8; InboundCipher::decryption_buffer_len(packet_length)]; bob.inbound() .decrypt( - black_box(&encrypted[3..3 + packet_length]), + black_box(&encrypted[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_length]), &mut decrypted, None, ) diff --git a/protocol/fuzz/Cargo.toml b/protocol/fuzz/Cargo.toml index f0531e5..e1d6c35 100644 --- a/protocol/fuzz/Cargo.toml +++ b/protocol/fuzz/Cargo.toml @@ -12,8 +12,22 @@ libfuzzer-sys = "0.4" bip324 = { path = ".." } [[bin]] -name = "handshake" -path = "fuzz_targets/handshake.rs" +name = "receive_garbage" +path = "fuzz_targets/receive_garbage.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "receive_version" +path = "fuzz_targets/receive_version.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "receive_key" +path = "fuzz_targets/receive_key.rs" test = false doc = false bench = false diff --git a/protocol/fuzz/fuzz_targets/handshake.rs b/protocol/fuzz/fuzz_targets/handshake.rs deleted file mode 100644 index f9b61df..0000000 --- a/protocol/fuzz/fuzz_targets/handshake.rs +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: CC0-1.0 - -//! ## Expected Outcomes -//! -//! * Most runs will fail with invalid EC points or handshake failures. -//! * No panics, crashes, or memory safety issues should occur. -//! * The implementation should handle all inputs gracefully. - -#![no_main] -use bip324::{Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role}; -use libfuzzer_sys::fuzz_target; - -fuzz_target!(|data: &[u8]| { - // Skip if data is too small for a meaningful test. - // We need at least 64 bytes for the public key and at - // least 30 more for some interesting garbage, decoy, version bytes. - if data.len() < 100 { - return; - } - - // Initiator side of the handshake. - let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - let mut initiator_pubkey = vec![0u8; Handshake::::send_key_len(None)]; - let handshake = handshake.send_key(None, &mut initiator_pubkey).unwrap(); - - // Use the first 64 bytes of fuzz data as the responder's public key. - let mut fuzzed_responder_pubkey = [0u8; 64]; - fuzzed_responder_pubkey.copy_from_slice(&data[..64]); - - // Attempt to receive the fuzzed key. - let handshake = match handshake.receive_key(fuzzed_responder_pubkey) { - Ok(h) => h, - Err(_) => return, // Invalid key rejected successfully. - }; - - // Send version message just to move the state of the handshake along. - let mut version_buffer = vec![0u8; Handshake::::send_version_len(None)]; - let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); - - // Try to receive and authenticate the fuzzed garbage and version data. - let mut garbage_and_version = Vec::from(&data[64..]); - match handshake.receive_version(&mut garbage_and_version) { - Ok(HandshakeAuthentication::Complete { .. }) => { - // Handshake completed successfully. - // This should only happen with some very lucky random bytes. - } - Ok(HandshakeAuthentication::NeedMoreData(_)) => { - // Handshake needs more ciphertext. - // This is an expected outcome for fuzzed inputs. - } - Err(_) => { - // Authentication or parsing failed. - // This is an expected outcome for fuzzed inputs. - } - } -}); diff --git a/protocol/fuzz/fuzz_targets/receive_garbage.rs b/protocol/fuzz/fuzz_targets/receive_garbage.rs new file mode 100644 index 0000000..5e6a9ae --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_garbage.rs @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_garbage function. +//! +//! This focused test fuzzes only the garbage terminator detection logic, +//! which is more effective than trying to fuzz the entire handshake. + +#![no_main] +use bip324::{GarbageResult, Handshake, Initialized, Network, ReceivedKey, Role}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Set up a valid handshake in the SentVersion state + let initiator = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut initiator_key = vec![0u8; Handshake::::send_key_len(None)]; + let initiator = initiator.send_key(None, &mut initiator_key).unwrap(); + + let responder = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut responder_key = vec![0u8; Handshake::::send_key_len(None)]; + let responder = responder.send_key(None, &mut responder_key).unwrap(); + + // Exchange keys using real keys to get valid ECDH shared secrets + let initiator = initiator + .receive_key(responder_key[..64].try_into().unwrap()) + .unwrap(); + let _responder = responder + .receive_key(initiator_key[..64].try_into().unwrap()) + .unwrap(); + + // Send version to reach SentVersion state + let mut initiator_version = vec![0u8; Handshake::::send_version_len(None)]; + let initiator = initiator + .send_version(&mut initiator_version, None) + .unwrap(); + + // Now fuzz the receive_garbage function with arbitrary data + match initiator.receive_garbage(data) { + Ok(GarbageResult::FoundGarbage { + handshake: _, + consumed_bytes, + }) => { + // Successfully found garbage terminator + // Verify consumed_bytes is reasonable + assert!(consumed_bytes <= data.len()); + assert!(consumed_bytes >= 16); // At least the terminator size + + // The garbage should be everything before the terminator + let garbage_len = consumed_bytes - 16; + assert!(garbage_len <= 4095); // Max garbage size + } + Ok(GarbageResult::NeedMoreData(_)) => { + // Need more data - valid outcome for short inputs + // This should happen when: + // 1. Buffer is too short to contain terminator + // 2. Buffer doesn't contain the terminator yet + } + Err(_) => { + // Error parsing garbage - valid outcome + // This should happen when: + // 1. No terminator found within max garbage size + } + } +}); diff --git a/protocol/fuzz/fuzz_targets/receive_key.rs b/protocol/fuzz/fuzz_targets/receive_key.rs new file mode 100644 index 0000000..9301c23 --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_key.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_key function. +//! +//! This focused test fuzzes the elliptic curve point validation and ECDH logic. + +#![no_main] +use bip324::{Handshake, Initialized, Network, Role}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Skip if data is not exactly 64 bytes + if data.len() != 64 { + return; + } + + // Set up a handshake in the SentKey state + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut key_buffer).unwrap(); + + // Fuzz the receive_key function with arbitrary 64-byte data + let mut key_bytes = [0u8; 64]; + key_bytes.copy_from_slice(data); + + match handshake.receive_key(key_bytes) { + Ok(_handshake) => { + // Successfully processed the key + // This means: + // 1. The 64 bytes represent a valid ElligatorSwift encoding + // 2. The ECDH operation succeeded + // 3. The key derivation worked + // 4. It's not the V1 protocol magic bytes + } + Err(_) => { + // Failed to process the key + // This could be: + // 1. Invalid ElligatorSwift encoding + // 2. V1 protocol detected (first 4 bytes match network magic) + // 3. ECDH or key derivation failure + } + } +}); diff --git a/protocol/fuzz/fuzz_targets/receive_version.rs b/protocol/fuzz/fuzz_targets/receive_version.rs new file mode 100644 index 0000000..bbb23fa --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_version.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_version function. +//! +//! This focused test fuzzes only the version packet decryption logic. + +#![no_main] +use bip324::{GarbageResult, Handshake, Initialized, Network, ReceivedKey, Role, VersionResult}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Skip if data is too small + if data.is_empty() { + return; + } + + // Set up a valid handshake in the ReceivedGarbage state + let initiator = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut initiator_key = vec![0u8; Handshake::::send_key_len(None)]; + let initiator = initiator.send_key(None, &mut initiator_key).unwrap(); + + let responder = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut responder_key = vec![0u8; Handshake::::send_key_len(None)]; + let responder = responder.send_key(None, &mut responder_key).unwrap(); + + // Exchange keys + let initiator = initiator + .receive_key(responder_key[..64].try_into().unwrap()) + .unwrap(); + let responder = responder + .receive_key(initiator_key[..64].try_into().unwrap()) + .unwrap(); + + // Both send version packets + let mut initiator_version = vec![0u8; Handshake::::send_version_len(None)]; + let initiator = initiator + .send_version(&mut initiator_version, None) + .unwrap(); + + let mut responder_version = vec![0u8; Handshake::::send_version_len(None)]; + let _responder = responder + .send_version(&mut responder_version, None) + .unwrap(); + + // Process the responder's garbage terminator to get to ReceivedGarbage state + let (handshake, _consumed) = match initiator.receive_garbage(&responder_version) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => (handshake, consumed_bytes), + _ => panic!("Should find garbage terminator in valid version buffer"), + }; + + // Now fuzz the receive_version function with arbitrary packet data + let mut packet_data = data.to_vec(); + match handshake.receive_version(&mut packet_data) { + Ok(VersionResult::Complete { cipher: _ }) => { + // Successfully completed handshake + // This should only happen with valid encrypted version packet + } + Ok(VersionResult::Decoy(_)) => { + // Received a decoy packet + // This should happen when packet type indicates decoy + } + Err(_) => { + // Decryption or authentication failed + // This is the most common outcome with random data + } + } +}); diff --git a/protocol/src/handshake.rs b/protocol/src/handshake.rs index 20634ed..aa6f3f1 100644 --- a/protocol/src/handshake.rs +++ b/protocol/src/handshake.rs @@ -21,16 +21,9 @@ use rand::Rng; use crate::{ CipherSession, Error, OutboundCipher, PacketType, Role, SessionKeyMaterial, - NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES, VERSION_CONTENT, + NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, VERSION_CONTENT, }; -/// Initial buffer length hint to receive garbage, the garbage terminator, -/// and the version packet from the remote peer. This assumes no decoy packets -/// are sent (which is the default in Bitcoin Core) and no garbage bytes. -/// Calculated as: garbage_terminator (16 bytes) + encrypted_version_packet -/// where the version packet content is currently empty (0 bytes). -pub const NUM_INITIAL_BUFFER_BYTES_HINT: usize = - NUM_GARBAGE_TERMINTOR_BYTES + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); // Maximum number of garbage bytes before the terminator. const MAX_NUM_GARBAGE_BYTES: usize = 4095; @@ -49,7 +42,6 @@ pub struct Initialized { /// **Second state** after sending the local public key. pub struct SentKey<'a> { point: EcdhPoint, - bytes_written: usize, local_garbage: Option<&'a [u8]>, } @@ -60,26 +52,37 @@ pub struct ReceivedKey<'a> { local_garbage: Option<&'a [u8]>, } -/// **Fourth state** after sending the version packet. +/// **Fourth state** after sending the decoy and version packets. pub struct SentVersion { cipher: CipherSession, remote_garbage_terminator: [u8; NUM_GARBAGE_TERMINTOR_BYTES], - bytes_written: usize, - ciphertext_index: usize, - remote_garbage_authenticated: bool, } -/// Success variants for receive_version. -pub enum HandshakeAuthentication { - /// Successfully completed. - Complete { - cipher: CipherSession, - bytes_consumed: usize, +/// **Fifth state** after receiving the remote's garbage and garbage terminator. +pub struct ReceivedGarbage<'a> { + cipher: CipherSession, + remote_garbage: Option<&'a [u8]>, +} + +/// Success variants for reading remote garbage. +pub enum GarbageResult<'a> { + /// Successfully found garbage. + FoundGarbage { + handshake: Handshake>, + consumed_bytes: usize, }, - /// Need more data - returns handshake for caller to retry with more ciphertext. + /// No garbage terminator found, the input buffer needs to be extended. NeedMoreData(Handshake), } +/// Success variants for receiving remote version. +pub enum VersionResult<'a> { + /// Successfully completed handshake. + Complete { cipher: CipherSession }, + /// Packet was a decoy, read the next to see if version. + Decoy(Handshake>), +} + /// Handshake state-machine to establish the secret material in the communication channel. /// /// The handshake progresses through multiple states, enforcing the protocol sequence at compile time. @@ -98,7 +101,7 @@ pub struct Handshake { state: State, } -// Methods available in all states +// Methods available in all states. impl Handshake { /// Get the network this handshake is operating on. pub fn network(&self) -> Network { @@ -111,7 +114,6 @@ impl Handshake { } } -// Initialized state implementation impl Handshake { /// Initialize a V2 transport handshake with a remote peer. #[cfg(feature = "std")] @@ -189,12 +191,11 @@ impl Handshake { // Write local ellswift public key. output_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] .copy_from_slice(&self.state.point.elligator_swift.to_array()); - let mut written = NUM_ELLIGATOR_SWIFT_BYTES; // Write garbage if provided. if let Some(g) = garbage { - output_buffer[written..written + g.len()].copy_from_slice(g); - written += g.len(); + output_buffer[NUM_ELLIGATOR_SWIFT_BYTES..NUM_ELLIGATOR_SWIFT_BYTES + g.len()] + .copy_from_slice(g); } Ok(Handshake { @@ -202,20 +203,13 @@ impl Handshake { role: self.role, state: SentKey { point: self.state.point, - bytes_written: written, local_garbage: garbage, }, }) } } -// SentKey state implementation impl<'a> Handshake> { - /// Get how many bytes were written by send_key(). - pub fn bytes_written(&self) -> usize { - self.state.bytes_written - } - /// Process the remote peer's public key and derive shared session secrets. /// /// This is the **second state transition** in the handshake process, moving from @@ -287,7 +281,6 @@ impl<'a> Handshake> { } } -// ReceivedKey state implementation impl<'a> Handshake> { /// Calculate how many bytes send_version() will write to buffer. pub fn send_version_len(decoys: Option<&[&[u8]]>) -> usize { @@ -372,7 +365,6 @@ impl<'a> Handshake> { } } - // Write version packet let version_packet_len = OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); cipher.outbound().encrypt( &VERSION_CONTENT, @@ -380,7 +372,6 @@ impl<'a> Handshake> { PacketType::Genuine, aad, )?; - bytes_written += version_packet_len; Ok(Handshake { network: self.network, @@ -388,101 +379,105 @@ impl<'a> Handshake> { state: SentVersion { cipher, remote_garbage_terminator, - bytes_written, - ciphertext_index: 0, - remote_garbage_authenticated: false, }, }) } } -// SentVersion state implementation impl Handshake { - /// Get how many bytes were written by send_version(). - pub fn bytes_written(&self) -> usize { - self.state.bytes_written - } - - /// Authenticate remote peer's garbage, decoy packets, and version packet. + /// Process remote peer's garbage bytes and locate the garbage terminator. /// - /// This method is unique in the handshake process as it requires a **mutable** input buffer - /// to perform in-place decryption operations. The buffer contains everything after the 64 - /// byte public key received from the remote peer: optional garbage bytes, garbage terminator, - /// and encrypted packets (decoys and final version packet). - /// - /// The input buffer is mutable in the case because the caller generally doesn't care - /// about the decoy and version packets, and definitely doesn't want to deal with - /// allocating memory for them. + /// This is a critical step in the handshake process, transitioning from + /// `SentVersion` to `ReceivedGarbage` state. The method searches for the remote peer's + /// garbage terminator within the input buffer and separates garbage bytes from the + /// subsequent encrypted packet data. /// /// # Parameters /// - /// * `input_buffer` - **Mutable** buffer containing garbage + terminator + encrypted packets. - /// The buffer will be modified during in-place decryption operations. + /// * `input_buffer` - Buffer containing remote peer's garbage bytes followed by encrypted + /// packet data. The garbage terminator marks the boundary between these sections. /// /// # Returns /// - /// * `Complete { cipher, bytes_consumed }` - Handshake succeeded, secure session established. - /// * `NeedMoreData(handshake)` - Insufficient data, retry by extending the buffer. + /// * `Ok(GarbageResult::FoundGarbage)` - Successfully located garbage terminator. + /// Contains the next handshake state and number of bytes consumed from input buffer. + /// * `Ok(GarbageResult::NeedMoreData)` - Garbage terminator not found in current buffer. + /// More data needed to locate the terminator boundary. + /// + /// # Errors + /// + /// * `NoGarbageTerminator` - Input exceeds maximum garbage size without finding terminator. + /// Indicates protocol violation or potential attack. + /// + /// # Example + /// + /// ```rust + /// use bip324::{Handshake, GarbageResult, SentVersion}; + /// # use bip324::{Role, Network}; + /// # fn example() -> Result<(), Box> { + /// # let mut handshake = Handshake::new(Network::Bitcoin, Role::Initiator)?; + /// # // ... complete handshake to SentVersion state ... + /// # let handshake: Handshake = todo!(); + /// + /// let mut network_buffer = Vec::new(); + /// + /// loop { + /// // Read more network data... + /// // network_buffer.extend_from_slice(&new_data); + /// + /// match handshake.receive_garbage(&network_buffer)? { + /// GarbageResult::FoundGarbage { handshake: next_state, consumed_bytes } => { + /// // Success! Process remaining data for version packets + /// let remaining_data = &network_buffer[consumed_bytes..]; + /// // Continue with next_state.receive_version()... + /// break; + /// } + /// GarbageResult::NeedMoreData(handshake) => { + /// // Continue accumulating network data + /// continue; + /// } + /// } + /// } + /// # Ok(()) + /// # } /// ``` - pub fn receive_version( - mut self, - input_buffer: &mut [u8], - ) -> Result { - let (garbage, ciphertext) = match self.split_garbage(input_buffer) { - Ok(split) => split, - Err(Error::CiphertextTooSmall) => { - return Ok(HandshakeAuthentication::NeedMoreData(self)) - } - Err(e) => return Err(e), - }; - - let mut aad = if garbage.is_empty() || self.state.remote_garbage_authenticated { - None - } else { - Some(garbage) - }; - - // First packet authenticates remote garbage. - // Continue through decoys until we find version packet. - loop { - match self.decrypt_packet(&mut ciphertext[self.state.ciphertext_index..], aad) { - Ok((packet_type, bytes_consumed)) => { - if aad.is_some() { - aad = None; - self.state.remote_garbage_authenticated = true; - } - self.state.ciphertext_index += bytes_consumed; - if matches!(packet_type, PacketType::Genuine) { - break; - } - } - Err(Error::CiphertextTooSmall) => { - return Ok(HandshakeAuthentication::NeedMoreData(self)) - } - Err(e) => return Err(e), + pub fn receive_garbage<'a>(self, input_buffer: &'a [u8]) -> Result, Error> { + match self.split_garbage(input_buffer) { + Ok((garbage, _ciphertext)) => { + let consumed_bytes = garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES; + let handshake = Handshake { + network: self.network, + role: self.role, + state: ReceivedGarbage { + cipher: self.state.cipher, + remote_garbage: if garbage.is_empty() { + None + } else { + Some(garbage) + }, + }, + }; + + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) } + Err(Error::CiphertextTooSmall) => Ok(GarbageResult::NeedMoreData(self)), + Err(e) => Err(e), } - - // Calculate total bytes consumed. - let bytes_consumed = - garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + self.state.ciphertext_index; - - Ok(HandshakeAuthentication::Complete { - cipher: self.state.cipher, - bytes_consumed, - }) } /// Split buffer on garbage terminator. - fn split_garbage<'b>(&self, buffer: &'b mut [u8]) -> Result<(&'b [u8], &'b mut [u8]), Error> { + fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { let terminator = &self.state.remote_garbage_terminator; if let Some(index) = buffer .windows(terminator.len()) .position(|window| window == terminator) { - let (garbage, rest) = buffer.split_at_mut(index); - let ciphertext = &mut rest[terminator.len()..]; + let (garbage, rest) = buffer.split_at(index); + let ciphertext = &rest[terminator.len()..]; Ok((garbage, ciphertext)) } else if buffer.len() >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES { Err(Error::NoGarbageTerminator) @@ -490,33 +485,100 @@ impl Handshake { Err(Error::CiphertextTooSmall) } } +} - /// Decrypts in place, returns the packet type and number of bytes consumed from the ciphertext. - fn decrypt_packet( - &mut self, - ciphertext: &mut [u8], - aad: Option<&[u8]>, - ) -> Result<(PacketType, usize), Error> { - if ciphertext.len() < NUM_LENGTH_BYTES { - return Err(Error::CiphertextTooSmall); - } - - let packet_len = self.state.cipher.inbound().decrypt_packet_len( - ciphertext[..NUM_LENGTH_BYTES] - .try_into() - .expect("Checked length above"), - ); +impl<'a> Handshake> { + /// Decrypt the packet length from the encrypted length bytes. + pub fn decrypt_packet_len(&mut self, length_bytes: [u8; 3]) -> Result { + Ok(self.state.cipher.inbound().decrypt_packet_len(length_bytes)) + } - if ciphertext.len() < NUM_LENGTH_BYTES + packet_len { - return Err(Error::CiphertextTooSmall); + /// Decrypt and authenticate the next packet to complete the handshake. + /// + /// This is the **final state transition** in the handshake process, completing the + /// BIP-324 protocol by processing the remote peer's version packet. The method performs + /// in-place decryption of encrypted packet data and determines whether the handshake + /// is complete or if additional decoy packets need processing. + /// + /// # Unique Characteristics + /// + /// **Mutable Buffer Requirement**: Unlike other handshake methods, `receive_version()` + /// requires a mutable input buffer because it performs in-place decryption operations, + /// modifying ciphertext directly to produce plaintext for memory efficiency. + /// + /// # Parameters + /// + /// * `input_buffer` - **Mutable** buffer containing encrypted packet data (excluding + /// the 3-byte length prefix). The ciphertext will be overwritten with plaintext + /// during decryption. Buffer size must match the decrypted packet length. + /// + /// # Returns + /// + /// * `Ok(VersionResult::Complete { cipher })` - Handshake completed successfully. + /// The returned `CipherSession` is ready for secure message exchange. + /// * `Ok(VersionResult::Decoy(handshake))` - Packet was a decoy, continue processing + /// with the returned handshake state for the next packet. + /// + /// # Errors + /// + /// * `Decryption` - Packet authentication failed or ciphertext is malformed. + /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet + /// + /// # Example + /// + /// ```rust + /// use bip324::{Handshake, VersionResult, ReceivedGarbage, NUM_LENGTH_BYTES}; + /// # use bip324::{Role, Network}; + /// # fn example() -> Result<(), Box> { + /// # let mut handshake = Handshake::new(Network::Bitcoin, Role::Initiator)?; + /// # // ... complete handshake to ReceivedGarbage state ... + /// # let mut handshake: Handshake = todo!(); + /// # let encrypted_data: &[u8] = todo!(); + /// + /// let mut remaining_data = encrypted_data; + /// + /// // Process packets until version packet found + /// loop { + /// // Read packet length (first 3 bytes) + /// let packet_len = handshake.decrypt_packet_len( + /// remaining_data[..NUM_LENGTH_BYTES].try_into()? + /// )?; + /// + /// // Extract packet data (excluding length prefix) + /// let mut packet = remaining_data[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + /// remaining_data = &remaining_data[NUM_LENGTH_BYTES + packet_len..]; + /// + /// // Process the packet + /// match handshake.receive_version(&mut packet)? { + /// VersionResult::Complete { cipher } => { + /// // Handshake complete! Ready for secure messaging + /// break; + /// } + /// VersionResult::Decoy(next_handshake) => { + /// // Decoy packet processed, continue with next packet + /// handshake = next_handshake; + /// } + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn receive_version(mut self, input_buffer: &mut [u8]) -> Result, Error> { + // Take the garbage on first call to ensure AAD is only used once. + let aad = self.state.remote_garbage.take(); + + let (packet_type, _) = self + .state + .cipher + .inbound() + .decrypt_in_place(input_buffer, aad)?; + + match packet_type { + PacketType::Genuine => Ok(VersionResult::Complete { + cipher: self.state.cipher, + }), + PacketType::Decoy => Ok(VersionResult::Decoy(self)), } - - let (packet_type, _) = self.state.cipher.inbound().decrypt_in_place( - &mut ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len], - aad, - )?; - - Ok((packet_type, NUM_LENGTH_BYTES + packet_len)) } } @@ -524,8 +586,12 @@ impl Handshake { mod tests { use std::vec; + use crate::NUM_LENGTH_BYTES; + use super::*; + // Test that the handshake completes successfully with garbage and decoy packets + // from both parties. This is a comprehensive integration test of the full protocol. #[test] fn test_handshake() { let initiator_garbage = vec![1u8, 2u8, 3u8]; @@ -540,7 +606,6 @@ mod tests { let init_handshake = init_handshake .send_key(Some(&initiator_garbage), &mut init_buffer) .unwrap(); - assert_eq!(init_handshake.bytes_written(), 64 + initiator_garbage.len()); let resp_handshake = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); @@ -551,16 +616,15 @@ mod tests { let resp_handshake = resp_handshake .send_key(Some(&responder_garbage), &mut resp_buffer) .unwrap(); - assert_eq!(resp_handshake.bytes_written(), 64 + responder_garbage.len()); // Initiator receives responder's key. let init_handshake = init_handshake - .receive_key(resp_buffer[..64].try_into().unwrap()) + .receive_key(resp_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) .unwrap(); // Responder receives initiator's key. let resp_handshake = resp_handshake - .receive_key(init_buffer[..64].try_into().unwrap()) + .receive_key(init_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) .unwrap(); // Create decoy packets for both sides. @@ -585,33 +649,91 @@ mod tests { .send_version(&mut resp_version_buffer, Some(&resp_decoys)) .unwrap(); - // Initiator receives responder's garbage, decoys, and version. - let mut full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); - match init_handshake - .receive_version(&mut full_resp_message) - .unwrap() - { - HandshakeAuthentication::Complete { bytes_consumed, .. } => { - assert_eq!(bytes_consumed, full_resp_message.len()); - } - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), - } + // Initiator processes responder's response + let full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); + + // First, find the garbage terminator + let (mut init_handshake, consumed) = + match init_handshake.receive_garbage(&full_resp_message).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process the encrypted packets (1 decoy + 1 version) + let mut remaining = &full_resp_message[consumed..]; + + // First packet is a decoy + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; - // Responder receives initiator's garbage, decoys, and version. - let mut full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); - match resp_handshake - .receive_version(&mut full_init_message) - .unwrap() - { - HandshakeAuthentication::Complete { bytes_consumed, .. } => { - assert_eq!(bytes_consumed, full_init_message.len()); - } - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + init_handshake = match init_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Decoy(handshake) => handshake, + VersionResult::Complete { .. } => panic!("First packet should be decoy"), + }; + + // Second packet is the version + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match init_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Second packet should be version"), + }; + + // Responder processes initiator's response + let full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); + + // First, find the garbage terminator + let (mut resp_handshake, consumed) = + match resp_handshake.receive_garbage(&full_init_message).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process the encrypted packets (2 decoys + 1 version) + let mut remaining = &full_init_message[consumed..]; + + // First two packets are decoys + for i in 0..2 { + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; + + resp_handshake = match resp_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Decoy(handshake) => handshake, + VersionResult::Complete { .. } => panic!("Packet {} should be decoy", i), + }; } + + // Third packet is the version + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match resp_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Third packet should be version"), + }; } + // Test that send_key properly validates garbage length limits (max 4095 bytes) + // and buffer size requirements. #[test] - fn test_handshake_garbage_length_check() { + fn test_handshake_send_key() { // Test with valid garbage length let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES]; let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); @@ -641,9 +763,11 @@ mod tests { assert!(result.is_ok()); } + // Test the NeedMoreData scenario where receive_garbage is called with partial data. + // The local peer doesn't know how much garbage the remote will send, so just needs + // to pull and do some buffer mangament. #[test] - fn test_handshake_receive_version_buffer() { - // Test the scenario where receive_version needs more data. + fn test_handshake_receive_garbage_buffer() { let init_handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); let resp_handshake = @@ -656,10 +780,10 @@ mod tests { let resp_handshake = resp_handshake.send_key(None, &mut resp_buffer).unwrap(); let init_handshake = init_handshake - .receive_key(resp_buffer[..64].try_into().unwrap()) + .receive_key(resp_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) .unwrap(); let resp_handshake = resp_handshake - .receive_key(init_buffer[..64].try_into().unwrap()) + .receive_key(init_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) .unwrap(); let mut init_version_buffer = @@ -674,47 +798,61 @@ mod tests { .send_version(&mut resp_version_buffer, None) .unwrap(); - // Feed data in very small chunks to trigger NeedMoreData. + // Test streaming scenario with receive_garbage let partial_data_1 = &init_version_buffer[..1]; - let mut partial_data_1 = partial_data_1.to_vec(); - let returned_handshake = match resp_handshake.receive_version(&mut partial_data_1).unwrap() - { - HandshakeAuthentication::NeedMoreData(handshake) => handshake, - HandshakeAuthentication::Complete { .. } => { + let returned_handshake = match resp_handshake.receive_garbage(partial_data_1).unwrap() { + GarbageResult::NeedMoreData(handshake) => handshake, + GarbageResult::FoundGarbage { .. } => { panic!("Should have needed more data with 1 byte") } }; // Feed a bit more data - still probably not enough. let partial_data_2 = &init_version_buffer[..5]; - let mut partial_data_2 = partial_data_2.to_vec(); - let returned_handshake = match returned_handshake - .receive_version(&mut partial_data_2) - .unwrap() - { - HandshakeAuthentication::NeedMoreData(handshake) => handshake, - HandshakeAuthentication::Complete { .. } => { + let returned_handshake = match returned_handshake.receive_garbage(partial_data_2).unwrap() { + GarbageResult::NeedMoreData(handshake) => handshake, + GarbageResult::FoundGarbage { .. } => { panic!("Should have needed more data with 5 bytes") } }; - // Now provide the complete data. - let mut full_data = init_version_buffer.clone(); - match returned_handshake.receive_version(&mut full_data).unwrap() { - HandshakeAuthentication::Complete { bytes_consumed, .. } => { - assert_eq!(bytes_consumed, init_version_buffer.len()); - } - HandshakeAuthentication::NeedMoreData(_) => { - panic!("Should have completed with full data") + // Now provide enough data to find the garbage terminator. + // Since there's no garbage, the terminator should be at the beginning. + let (mut handshake, consumed) = match returned_handshake + .receive_garbage(&init_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => { + panic!("Should have found garbage terminator with full data") } - } + }; + + // Process the version packet + let remaining = &init_version_buffer[consumed..]; + let packet_len = handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Should be version packet"), + }; } + // Test split_garbage error conditions. + // + // 1. NoGarbageTerminator - when buffer exceeds max size without finding terminator + // 2. CiphertextTooSmall - when buffer is too short to possibly contain terminator #[test] - fn test_handshake_no_garbage_terminator() { + fn test_handshake_split_garbage() { // Create a handshake and bring it to the SentVersion state to test split_garbage let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - let mut buffer = vec![0u8; 64]; + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; let handshake = handshake.send_key(None, &mut buffer).unwrap(); // Create a fake peer key to receive @@ -726,13 +864,39 @@ mod tests { let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); // Test with a buffer that is too long (should fail to find terminator) - let mut test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; - let result = handshake.split_garbage(&mut test_buffer); + let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; + let result = handshake.split_garbage(&test_buffer); assert!(matches!(result, Err(Error::NoGarbageTerminator))); // Test with a buffer that's just short of the required length - let mut short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; - let result = handshake.split_garbage(&mut short_buffer); + let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; + let result = handshake.split_garbage(&short_buffer); assert!(matches!(result, Err(Error::CiphertextTooSmall))); } + + // Test that receive_key detects V1 protocol when peer's key starts with network magic. + #[test] + fn test_v1_protocol_detection() { + // Test that receive_key properly detects V1 protocol magic bytes + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let handshake = handshake.send_key(None, &mut buffer).unwrap(); + + // Create a key that starts with Bitcoin mainnet magic bytes + let mut v1_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + v1_key[..4].copy_from_slice(&Network::Bitcoin.magic().to_bytes()); + + let result = handshake.receive_key(v1_key); + assert!(matches!(result, Err(Error::V1Protocol))); + + // Test with different networks + let handshake = Handshake::::new(Network::Testnet, Role::Responder).unwrap(); + let handshake = handshake.send_key(None, &mut buffer).unwrap(); + + let mut v1_testnet_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + v1_testnet_key[..4].copy_from_slice(&Network::Testnet.magic().to_bytes()); + + let result = handshake.receive_key(v1_testnet_key); + assert!(matches!(result, Err(Error::V1Protocol))); + } } diff --git a/protocol/src/io.rs b/protocol/src/io.rs index bafc90f..40d34ab 100644 --- a/protocol/src/io.rs +++ b/protocol/src/io.rs @@ -20,9 +20,9 @@ use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::{ - handshake::{self, HandshakeAuthentication}, + handshake::{self, GarbageResult, VersionResult}, Error, Handshake, InboundCipher, OutboundCipher, PacketType, Role, NUM_ELLIGATOR_SWIFT_BYTES, - NUM_INITIAL_BUFFER_BYTES_HINT, + NUM_GARBAGE_TERMINTOR_BYTES, }; /// A decrypted BIP324 payload with its packet type. @@ -101,6 +101,23 @@ impl From for ProtocolError { } } +#[cfg(feature = "std")] +impl ProtocolError { + /// Create an EOF error that suggests retrying with V1 protocol. + /// + /// This is used when the remote peer closes the connection during handshake, + /// which often indicates they don't support the V2 protocol. + fn eof() -> Self { + ProtocolError::Io( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Remote peer closed connection during handshake", + ), + ProtocolFailureSuggestion::RetryV1, + ) + } +} + #[cfg(feature = "std")] impl std::error::Error for ProtocolError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { @@ -180,9 +197,7 @@ impl AsyncProtocol { let key_buffer_len = Handshake::::send_key_len(garbage); let mut key_buffer = vec![0u8; key_buffer_len]; let handshake = handshake.send_key(garbage, &mut key_buffer)?; - writer - .write_all(&key_buffer[..handshake.bytes_written()]) - .await?; + writer.write_all(&key_buffer).await?; writer.flush().await?; // Read remote's public key. @@ -194,62 +209,73 @@ impl AsyncProtocol { let version_buffer_len = Handshake::::send_version_len(decoys); let mut version_buffer = vec![0u8; version_buffer_len]; let handshake = handshake.send_version(&mut version_buffer, decoys)?; - writer - .write_all(&version_buffer[..handshake.bytes_written()]) - .await?; + writer.write_all(&version_buffer).await?; writer.flush().await?; - // Receive and authenticate remote garbage and version. - let mut remote_garbage_and_version_buffer = vec![0u8; NUM_INITIAL_BUFFER_BYTES_HINT]; - let mut bytes_read_so_far = 0; + // Receive and process garbage terminator + let mut garbage_buffer = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES]; + reader.read_exact(&mut garbage_buffer).await?; + let mut handshake = handshake; + let (mut handshake, garbage_bytes) = loop { + match handshake.receive_garbage(&garbage_buffer) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => { + break (handshake, consumed_bytes); + } + Ok(GarbageResult::NeedMoreData(h)) => { + handshake = h; + // Use small chunks to avoid reading past garbage, decoys, and version. + let mut temp = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES]; + match reader.read(&mut temp).await { + Ok(0) => return Err(ProtocolError::eof()), + Ok(n) => { + garbage_buffer.extend_from_slice(&temp[..n]); + } + Err(e) => return Err(ProtocolError::from(e)), + } + } + Err(e) => return Err(ProtocolError::Internal(e)), + } + }; + // Process remaining bytes and read version packets. + let mut version_buffer = garbage_buffer[garbage_bytes..].to_vec(); loop { - match reader - .read(&mut remote_garbage_and_version_buffer[bytes_read_so_far..]) - .await - { - Ok(0) => { - // EOF - remote closed connection. - return Err(ProtocolError::Io( - std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "Remote peer closed connection during handshake", - ), - ProtocolFailureSuggestion::RetryV1, - )); + // Decrypt packet length. + if version_buffer.len() < 3 { + let old_len = version_buffer.len(); + version_buffer.resize(3, 0); + reader.read_exact(&mut version_buffer[old_len..]).await?; + } + let packet_len = + handshake.decrypt_packet_len(version_buffer[..3].try_into().unwrap())?; + version_buffer.drain(..3); + + // Process packet. + if version_buffer.len() < packet_len { + let old_len = version_buffer.len(); + version_buffer.resize(packet_len, 0); + reader.read_exact(&mut version_buffer[old_len..]).await?; + } + match handshake.receive_version(&mut version_buffer) { + Ok(VersionResult::Complete { cipher }) => { + let (inbound_cipher, outbound_cipher) = cipher.into_split(); + return Ok(Self { + reader: AsyncProtocolReader { + inbound_cipher, + state: DecryptState::init_reading_length(), + }, + writer: AsyncProtocolWriter { outbound_cipher }, + }); } - Ok(bytes_read) => { - bytes_read_so_far += bytes_read; - - handshake = match handshake.receive_version( - &mut remote_garbage_and_version_buffer[..bytes_read_so_far], - ) { - Ok(HandshakeAuthentication::Complete { cipher, .. }) => { - let (inbound_cipher, outbound_cipher) = cipher.into_split(); - return Ok(Self { - reader: AsyncProtocolReader { - inbound_cipher, - state: DecryptState::init_reading_length(), - }, - writer: AsyncProtocolWriter { outbound_cipher }, - }); - } - Ok(HandshakeAuthentication::NeedMoreData(handshake)) => { - // Need more data - extend buffer for next read. - remote_garbage_and_version_buffer - .resize(remote_garbage_and_version_buffer.len() + 1024, 0); - handshake - } - Err(e) => return Err(ProtocolError::Internal(e)), - }; + Ok(VersionResult::Decoy(h)) => { + handshake = h; + version_buffer.drain(..packet_len); } - Err(e) => match e.kind() { - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => { - continue; - } - _ => return Err(ProtocolError::Io(e, ProtocolFailureSuggestion::Abort)), - }, + Err(e) => return Err(ProtocolError::Internal(e)), } } } diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 645b548..99224c7 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -40,8 +40,8 @@ use bitcoin_hashes::{hkdf, sha256, Hkdf}; pub use bitcoin::Network; pub use handshake::{ - Handshake, HandshakeAuthentication, Initialized, ReceivedKey, SentKey, SentVersion, - NUM_INITIAL_BUFFER_BYTES_HINT, + GarbageResult, Handshake, Initialized, ReceivedGarbage, ReceivedKey, SentKey, SentVersion, + VersionResult, }; // Re-exports from io module (async I/O types for backwards compatibility) #[cfg(any(feature = "futures", feature = "tokio"))] @@ -772,7 +772,7 @@ mod tests { } #[test] - fn test_authenticated_garbage() { + fn test_additional_authenticated_data() { let mut rng = rand::thread_rng(); let alice = SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") @@ -818,232 +818,6 @@ mod tests { .unwrap(); } - #[test] - fn test_handshake_with_garbage_and_decoys() { - use crate::handshake::{HandshakeAuthentication, Initialized}; - - // Define the garbage and decoys that the initiator is sending to the responder. - let initiator_garbage = vec![1u8, 2u8, 3u8]; - let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]]; - - // Define the garbage and decoys that the responder is sending to the initiator. - let responder_garbage = vec![4u8, 5u8]; - let responder_decoys: Vec<&[u8]> = vec![&[10u8, 11u8], &[12u8], &[13u8, 14u8, 15u8]]; - - // Create initiator handshake - let init_handshake = - Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - - // Send initiator key + garbage - let mut init_key_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()]; - let init_handshake = init_handshake - .send_key(Some(&initiator_garbage), &mut init_key_buffer) - .unwrap(); - - // Create responder handshake - let resp_handshake = - Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); - - // Send responder key + garbage - let mut resp_key_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()]; - let resp_handshake = resp_handshake - .send_key(Some(&responder_garbage), &mut resp_key_buffer) - .unwrap(); - - // Initiator receives responder's key - let init_handshake = init_handshake - .receive_key(resp_key_buffer[..64].try_into().unwrap()) - .unwrap(); - - // Responder receives initiator's key - let resp_handshake = resp_handshake - .receive_key(init_key_buffer[..64].try_into().unwrap()) - .unwrap(); - - // Initiator sends version - let mut init_version_buffer = vec![0u8; 1024]; - let init_handshake = init_handshake - .send_version(&mut init_version_buffer, Some(&initiator_decoys)) - .unwrap(); - let init_version_len = init_handshake.bytes_written(); - - // Responder sends version - let mut resp_version_buffer = vec![0u8; 1024]; - let resp_handshake = resp_handshake - .send_version(&mut resp_version_buffer, Some(&responder_decoys)) - .unwrap(); - let resp_version_len = resp_handshake.bytes_written(); - - // Initiator receives responder's version - let mut full_resp_message = [ - &responder_garbage[..], - &resp_version_buffer[..resp_version_len], - ] - .concat(); - let mut alice = match init_handshake - .receive_version(&mut full_resp_message) - .unwrap() - { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), - }; - - // Responder receives initiator's version - let mut full_init_message = [ - &initiator_garbage[..], - &init_version_buffer[..init_version_len], - ] - .concat(); - let mut bob = match resp_handshake - .receive_version(&mut full_init_message) - .unwrap() - { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), - }; - - let message = b"Hello world".to_vec(); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut encrypted_message_to_alice = vec![0u8; packet_len]; - bob.outbound() - .encrypt( - &message, - &mut encrypted_message_to_alice, - PacketType::Genuine, - None, - ) - .unwrap(); - - let bob_message_len = alice.inbound().decrypt_packet_len( - encrypted_message_to_alice[..NUM_LENGTH_BYTES] - .try_into() - .unwrap(), - ); - let mut dec = vec![0u8; InboundCipher::decryption_buffer_len(bob_message_len)]; - alice - .inbound() - .decrypt( - &encrypted_message_to_alice[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + bob_message_len], - &mut dec, - None, - ) - .unwrap(); - assert_eq!(message, dec[1..].to_vec()); // Skip header byte - - let message = b"g!".to_vec(); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut encrypted_message_to_bob = vec![0u8; packet_len]; - alice - .outbound() - .encrypt( - &message, - &mut encrypted_message_to_bob, - PacketType::Genuine, - None, - ) - .unwrap(); - - let alice_message_len = bob.inbound().decrypt_packet_len( - encrypted_message_to_bob[..NUM_LENGTH_BYTES] - .try_into() - .unwrap(), - ); - let mut dec = vec![0u8; InboundCipher::decryption_buffer_len(alice_message_len)]; - bob.inbound() - .decrypt( - &encrypted_message_to_bob[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + alice_message_len], - &mut dec, - None, - ) - .unwrap(); - assert_eq!(message, dec[1..].to_vec()); // Skip header byte - } - - #[test] - fn test_partial_decodings() { - use crate::handshake::{HandshakeAuthentication, Initialized}; - let mut rng = rand::thread_rng(); - - // Create initiator handshake - let init_handshake = - Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - - // Send initiator key (no garbage) - let mut init_key_buffer = vec![0u8; 64]; - let init_handshake = init_handshake.send_key(None, &mut init_key_buffer).unwrap(); - - // Create responder handshake - let resp_handshake = - Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); - - // Send responder key (no garbage) - let mut resp_key_buffer = vec![0u8; 64]; - let resp_handshake = resp_handshake.send_key(None, &mut resp_key_buffer).unwrap(); - - // Initiator receives responder's key - let init_handshake = init_handshake - .receive_key(resp_key_buffer.try_into().unwrap()) - .unwrap(); - - // Responder receives initiator's key - let resp_handshake = resp_handshake - .receive_key(init_key_buffer.try_into().unwrap()) - .unwrap(); - - // Initiator sends version - let mut init_version_buffer = vec![0u8; 36]; - let init_handshake = init_handshake - .send_version(&mut init_version_buffer, None) - .unwrap(); - - // Responder sends version - let mut resp_version_buffer = vec![0u8; 36]; - let resp_handshake = resp_handshake - .send_version(&mut resp_version_buffer, None) - .unwrap(); - - // Initiator receives responder's version - let mut alice = match init_handshake - .receive_version(&mut resp_version_buffer) - .unwrap() - { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), - }; - - // Responder receives initiator's version - let mut bob = match resp_handshake - .receive_version(&mut init_version_buffer) - .unwrap() - { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), - }; - - let mut message_to_bob = Vec::new(); - let message = gen_garbage(420, &mut rng); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut enc_packet = vec![0u8; packet_len]; - alice - .outbound() - .encrypt(&message, &mut enc_packet, PacketType::Genuine, None) - .unwrap(); - message_to_bob.extend(&enc_packet); - - let alice_message_len = bob - .inbound() - .decrypt_packet_len(message_to_bob[..3].try_into().unwrap()); - let mut contents = vec![0u8; InboundCipher::decryption_buffer_len(alice_message_len)]; - bob.inbound() - .decrypt( - &message_to_bob[3..3 + alice_message_len], - &mut contents, - None, - ) - .unwrap(); - assert_eq!(message, contents[1..].to_vec()); // Skip header byte - } - // The rest are sourced from [the BIP324 test vectors](https://github.com/bitcoin/bips/blob/master/bip-0324/packet_encoding_test_vectors.csv). #[test] diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 36340ff..0fdd7d0 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -5,7 +5,10 @@ const PORT: u16 = 18444; #[test] #[cfg(feature = "std")] fn hello_world_happy_path() { - use bip324::{Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, Role}; + use bip324::{ + GarbageResult, Handshake, Initialized, PacketType, ReceivedKey, Role, VersionResult, + NUM_LENGTH_BYTES, + }; use bitcoin::Network; // Create initiator handshake @@ -44,22 +47,50 @@ fn hello_world_happy_path() { .send_version(&mut resp_version_buffer, None) .unwrap(); - // Initiator receives responder's version - let mut alice = match init_handshake - .receive_version(&mut resp_version_buffer) + // Initiator receives responder's garbage and version + let (mut init_handshake, consumed) = match init_handshake + .receive_garbage(&resp_version_buffer) .unwrap() { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Process the version packet properly + let remaining = &resp_version_buffer[consumed..]; + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut version_packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + let mut alice = match init_handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), }; - // Responder receives initiator's version - let mut bob = match resp_handshake - .receive_version(&mut init_version_buffer) + // Responder receives initiator's garbage and version + let (mut resp_handshake, consumed) = match resp_handshake + .receive_garbage(&init_version_buffer) .unwrap() { - HandshakeAuthentication::Complete { cipher, .. } => cipher, - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Process the version packet properly + let remaining = &init_version_buffer[consumed..]; + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut version_packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + let mut bob = match resp_handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), }; // Alice and Bob can freely exchange encrypted messages using the packet handler returned by each handshake. @@ -75,15 +106,17 @@ fn hello_world_happy_path() { ) .unwrap(); - let alice_message_len = alice - .inbound() - .decrypt_packet_len(encrypted_message_to_alice[..3].try_into().unwrap()); + let alice_message_len = alice.inbound().decrypt_packet_len( + encrypted_message_to_alice[..NUM_LENGTH_BYTES] + .try_into() + .unwrap(), + ); let mut decrypted_message = vec![0u8; bip324::InboundCipher::decryption_buffer_len(alice_message_len)]; alice .inbound() .decrypt( - &encrypted_message_to_alice[3..], + &encrypted_message_to_alice[NUM_LENGTH_BYTES..], &mut decrypted_message, None, ) @@ -103,13 +136,19 @@ fn hello_world_happy_path() { ) .unwrap(); - let bob_message_len = bob - .inbound() - .decrypt_packet_len(encrypted_message_to_bob[..3].try_into().unwrap()); + let bob_message_len = bob.inbound().decrypt_packet_len( + encrypted_message_to_bob[..NUM_LENGTH_BYTES] + .try_into() + .unwrap(), + ); let mut decrypted_message = vec![0u8; bip324::InboundCipher::decryption_buffer_len(bob_message_len)]; bob.inbound() - .decrypt(&encrypted_message_to_bob[3..], &mut decrypted_message, None) + .decrypt( + &encrypted_message_to_bob[NUM_LENGTH_BYTES..], + &mut decrypted_message, + None, + ) .unwrap(); assert_eq!(message, decrypted_message[1..].to_vec()); // Skip header byte } @@ -125,7 +164,8 @@ fn regtest_handshake() { use bip324::{ serde::{deserialize, serialize, NetworkMessage}, - Handshake, HandshakeAuthentication, Initialized, PacketType, ReceivedKey, + GarbageResult, Handshake, Initialized, PacketType, ReceivedKey, VersionResult, + NUM_LENGTH_BYTES, }; use bitcoin::p2p::{message_network::VersionMessage, Address, ServiceFlags}; let bitcoind = regtest_process(TransportVersion::V2); @@ -165,15 +205,52 @@ fn regtest_handshake() { let mut max_response = [0; 4096]; println!("Reading the response buffer"); let size = stream.read(&mut max_response).unwrap(); - let response = &mut max_response[..size]; + let response = &max_response[..size]; println!("Authenticating the handshake"); - let cipher_session = match handshake.receive_version(response).unwrap() { - HandshakeAuthentication::Complete { cipher, .. } => { - println!("Finalizing the handshake"); - cipher + // First receive garbage + let (mut handshake, consumed) = match handshake.receive_garbage(response).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => { + println!("Found garbage terminator after {consumed_bytes} bytes"); + (handshake, consumed_bytes) + } + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Then receive version - properly handle packet length and potential decoys. + let mut remaining = &response[consumed..]; + let cipher_session = loop { + // Check if we have enough data for packet length + if remaining.len() < NUM_LENGTH_BYTES { + panic!("Not enough data for packet length"); + } + + let packet_len = handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + + if remaining.len() < NUM_LENGTH_BYTES + packet_len { + panic!("Not enough data for full packet"); + } + + let mut version_packet = + remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => { + println!("Finalizing the handshake"); + break cipher; + } + VersionResult::Decoy(h) => { + println!("Received decoy packet, continuing..."); + handshake = h; + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; + continue; + } } - HandshakeAuthentication::NeedMoreData(_) => panic!("Should have completed"), }; let (mut decrypter, mut encrypter) = cipher_session.into_split(); @@ -203,7 +280,7 @@ fn regtest_handshake() { println!("Serializing and writing version message"); stream.write_all(&packet).unwrap(); println!("Reading the response length buffer"); - let mut response_len = [0; 3]; + let mut response_len = [0; NUM_LENGTH_BYTES]; stream.read_exact(&mut response_len).unwrap(); let message_len = decrypter.decrypt_packet_len(response_len); let mut response_message = vec![0; message_len];