diff --git a/justfile b/justfile index ae02f82..fe2c160 100644 --- a/justfile +++ b/justfile @@ -75,8 +75,8 @@ _default: bench: cargo +{{NIGHTLY_TOOLCHAIN}} bench --package bip324 --bench cipher_session -# Run fuzz test: handshake. -@fuzz target="handshake" time="60": +# Run fuzz test: receive_key, receive_garbage, receive_version. +@fuzz target="receive_garbage" time="60": cargo install cargo-fuzz@0.12.0 cd protocol && cargo +{{NIGHTLY_TOOLCHAIN}} fuzz run {{target}} -- -max_total_time={{time}} diff --git a/protocol/benches/cipher_session.rs b/protocol/benches/cipher_session.rs index f206683..f0f3be3 100644 --- a/protocol/benches/cipher_session.rs +++ b/protocol/benches/cipher_session.rs @@ -1,60 +1,99 @@ +// SPDX-License-Identifier: CC0-1.0 + #![feature(test)] extern crate test; -use bip324::{CipherSession, Handshake, InboundCipher, Network, OutboundCipher, PacketType, Role}; +use bip324::{ + CipherSession, GarbageResult, Handshake, InboundCipher, Initialized, Network, OutboundCipher, + PacketType, ReceivedKey, Role, VersionResult, NUM_LENGTH_BYTES, +}; use test::{black_box, Bencher}; fn create_cipher_session_pair() -> (CipherSession, CipherSession) { - // Create a proper handshake between Alice and Bob. - let mut alice_init_buffer = vec![0u8; 64]; - let mut alice_handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - None, - &mut alice_init_buffer, - ) - .unwrap(); - - let mut bob_init_buffer = vec![0u8; 100]; - let mut bob_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut bob_init_buffer, - ) - .unwrap(); - - // Bob completes materials with Alice's key. - bob_handshake - .complete_materials( - alice_init_buffer[..64].try_into().unwrap(), - &mut bob_init_buffer[64..], - None, - ) + // Send Alice's key. + let alice_handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut alice_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let alice_handshake = alice_handshake + .send_key(None, &mut alice_key_buffer) + .unwrap(); + + // Send Bob's key + let bob_handshake = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut bob_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let bob_handshake = bob_handshake.send_key(None, &mut bob_key_buffer).unwrap(); + + // Alice receives Bob's key. + let alice_handshake = alice_handshake + .receive_key(bob_key_buffer.try_into().unwrap()) + .unwrap(); + + // Bob receives Alice's key. + let bob_handshake = bob_handshake + .receive_key(alice_key_buffer.try_into().unwrap()) + .unwrap(); + + // Alice sends version. + let mut alice_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let alice_handshake = alice_handshake + .send_version(&mut alice_version_buffer, None) .unwrap(); - // Alice completes materials with Bob's key. - let mut alice_response_buffer = vec![0u8; 36]; - alice_handshake - .complete_materials( - bob_init_buffer[..64].try_into().unwrap(), - &mut alice_response_buffer, - None, - ) + // Bob sends version. + let mut bob_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let bob_handshake = bob_handshake + .send_version(&mut bob_version_buffer, None) .unwrap(); - // Authenticate. - let mut packet_buffer = vec![0u8; 4096]; - alice_handshake - .authenticate_garbage_and_version(&bob_init_buffer[64..], &mut packet_buffer) + // Alice receives Bob's version. + // First handle Bob's garbage terminator + let (mut alice_handshake, consumed) = match alice_handshake + .receive_garbage(&bob_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process Bob's version packet + let remaining = &bob_version_buffer[consumed..]; + let packet_len = alice_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) .unwrap(); - bob_handshake - .authenticate_garbage_and_version(&alice_response_buffer, &mut packet_buffer) + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + let alice = match alice_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), + }; + + // Bob receives Alice's version. + // First handle Alice's garbage terminator + let (mut bob_handshake, consumed) = match bob_handshake + .receive_garbage(&alice_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process Alice's version packet + let remaining = &alice_version_buffer[consumed..]; + let packet_len = bob_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); - let alice = alice_handshake.finalize().unwrap(); - let bob = bob_handshake.finalize().unwrap(); + let bob = match bob_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), + }; (alice, bob) } @@ -78,16 +117,16 @@ fn bench_round_trip_small_packet(b: &mut Bencher) { ) .unwrap(); - // Decrypt the length from first 3 bytes (real-world step). - let packet_length = bob - .inbound() - .decrypt_packet_len(black_box(encrypted[0..3].try_into().unwrap())); + // Decrypt the length from first NUM_LENGTH_BYTES bytes (real-world step). + let packet_length = bob.inbound().decrypt_packet_len(black_box( + encrypted[0..NUM_LENGTH_BYTES].try_into().unwrap(), + )); // Decrypt the payload using the decrypted length. let mut decrypted = vec![0u8; InboundCipher::decryption_buffer_len(packet_length)]; bob.inbound() .decrypt( - black_box(&encrypted[3..3 + packet_length]), + black_box(&encrypted[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_length]), &mut decrypted, None, ) @@ -117,16 +156,16 @@ fn bench_round_trip_large_packet(b: &mut Bencher) { ) .unwrap(); - // Decrypt the length from first 3 bytes (real-world step). - let packet_length = bob - .inbound() - .decrypt_packet_len(black_box(encrypted[0..3].try_into().unwrap())); + // Decrypt the length from first NUM_LENGTH_BYTES bytes (real-world step). + let packet_length = bob.inbound().decrypt_packet_len(black_box( + encrypted[0..NUM_LENGTH_BYTES].try_into().unwrap(), + )); // Decrypt the payload using the decrypted length. let mut decrypted = vec![0u8; InboundCipher::decryption_buffer_len(packet_length)]; bob.inbound() .decrypt( - black_box(&encrypted[3..3 + packet_length]), + black_box(&encrypted[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_length]), &mut decrypted, None, ) diff --git a/protocol/fuzz/Cargo.toml b/protocol/fuzz/Cargo.toml index f0531e5..e1d6c35 100644 --- a/protocol/fuzz/Cargo.toml +++ b/protocol/fuzz/Cargo.toml @@ -12,8 +12,22 @@ libfuzzer-sys = "0.4" bip324 = { path = ".." } [[bin]] -name = "handshake" -path = "fuzz_targets/handshake.rs" +name = "receive_garbage" +path = "fuzz_targets/receive_garbage.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "receive_version" +path = "fuzz_targets/receive_version.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "receive_key" +path = "fuzz_targets/receive_key.rs" test = false doc = false bench = false diff --git a/protocol/fuzz/fuzz_targets/handshake.rs b/protocol/fuzz/fuzz_targets/handshake.rs deleted file mode 100644 index bd8c5c9..0000000 --- a/protocol/fuzz/fuzz_targets/handshake.rs +++ /dev/null @@ -1,58 +0,0 @@ -#![no_main] -use bip324::{Handshake, Network, Role, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES}; -use libfuzzer_sys::fuzz_target; - -fuzz_target!(|data: &[u8]| { - // Skip if data is too small for an interesting test. - if data.len() < 64 { - return; - } - - let mut initiator_pubkey = [0u8; 64]; - let mut handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - None, - &mut initiator_pubkey, - ) - .unwrap(); - - let mut responder_pubkey = [0u8; 64]; - let _responder_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut responder_pubkey, - ) - .unwrap(); - - // Create a mutation of the responder's bytes. - let mut garbage_and_version = [0u8; 36]; - let copy_len = std::cmp::min(data.len(), garbage_and_version.len()); - garbage_and_version[..copy_len].copy_from_slice(&data[..copy_len]); - - // Create mutation of the responder's public key. - // The key is either completely random or slightly tweaked. - let mut fuzzed_responder_pubkey = [0u8; 64]; - if data.len() >= 128 { - fuzzed_responder_pubkey.copy_from_slice(&data[64..128]); - } else { - fuzzed_responder_pubkey.copy_from_slice(&responder_pubkey); - for (i, b) in data - .iter() - .enumerate() - .take(fuzzed_responder_pubkey.len()) - .skip(copy_len) - { - fuzzed_responder_pubkey[i % 64] ^= b; // XOR to make controlled changes. - } - } - - // Try to complete the materials and authenticate with the fuzzed key and data. - // Exercising malformed public key handling. - let _ = handshake.complete_materials(fuzzed_responder_pubkey, &mut garbage_and_version, None); - // Check how a broken handshake is handled. - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; // Initial buffer for decoy and version packets - let _ = handshake.authenticate_garbage_and_version(&garbage_and_version, &mut packet_buffer); - let _ = handshake.finalize(); -}); diff --git a/protocol/fuzz/fuzz_targets/receive_garbage.rs b/protocol/fuzz/fuzz_targets/receive_garbage.rs new file mode 100644 index 0000000..5e6a9ae --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_garbage.rs @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_garbage function. +//! +//! This focused test fuzzes only the garbage terminator detection logic, +//! which is more effective than trying to fuzz the entire handshake. + +#![no_main] +use bip324::{GarbageResult, Handshake, Initialized, Network, ReceivedKey, Role}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Set up a valid handshake in the SentVersion state + let initiator = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut initiator_key = vec![0u8; Handshake::::send_key_len(None)]; + let initiator = initiator.send_key(None, &mut initiator_key).unwrap(); + + let responder = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut responder_key = vec![0u8; Handshake::::send_key_len(None)]; + let responder = responder.send_key(None, &mut responder_key).unwrap(); + + // Exchange keys using real keys to get valid ECDH shared secrets + let initiator = initiator + .receive_key(responder_key[..64].try_into().unwrap()) + .unwrap(); + let _responder = responder + .receive_key(initiator_key[..64].try_into().unwrap()) + .unwrap(); + + // Send version to reach SentVersion state + let mut initiator_version = vec![0u8; Handshake::::send_version_len(None)]; + let initiator = initiator + .send_version(&mut initiator_version, None) + .unwrap(); + + // Now fuzz the receive_garbage function with arbitrary data + match initiator.receive_garbage(data) { + Ok(GarbageResult::FoundGarbage { + handshake: _, + consumed_bytes, + }) => { + // Successfully found garbage terminator + // Verify consumed_bytes is reasonable + assert!(consumed_bytes <= data.len()); + assert!(consumed_bytes >= 16); // At least the terminator size + + // The garbage should be everything before the terminator + let garbage_len = consumed_bytes - 16; + assert!(garbage_len <= 4095); // Max garbage size + } + Ok(GarbageResult::NeedMoreData(_)) => { + // Need more data - valid outcome for short inputs + // This should happen when: + // 1. Buffer is too short to contain terminator + // 2. Buffer doesn't contain the terminator yet + } + Err(_) => { + // Error parsing garbage - valid outcome + // This should happen when: + // 1. No terminator found within max garbage size + } + } +}); diff --git a/protocol/fuzz/fuzz_targets/receive_key.rs b/protocol/fuzz/fuzz_targets/receive_key.rs new file mode 100644 index 0000000..9301c23 --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_key.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_key function. +//! +//! This focused test fuzzes the elliptic curve point validation and ECDH logic. + +#![no_main] +use bip324::{Handshake, Initialized, Network, Role}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Skip if data is not exactly 64 bytes + if data.len() != 64 { + return; + } + + // Set up a handshake in the SentKey state + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut key_buffer).unwrap(); + + // Fuzz the receive_key function with arbitrary 64-byte data + let mut key_bytes = [0u8; 64]; + key_bytes.copy_from_slice(data); + + match handshake.receive_key(key_bytes) { + Ok(_handshake) => { + // Successfully processed the key + // This means: + // 1. The 64 bytes represent a valid ElligatorSwift encoding + // 2. The ECDH operation succeeded + // 3. The key derivation worked + // 4. It's not the V1 protocol magic bytes + } + Err(_) => { + // Failed to process the key + // This could be: + // 1. Invalid ElligatorSwift encoding + // 2. V1 protocol detected (first 4 bytes match network magic) + // 3. ECDH or key derivation failure + } + } +}); diff --git a/protocol/fuzz/fuzz_targets/receive_version.rs b/protocol/fuzz/fuzz_targets/receive_version.rs new file mode 100644 index 0000000..bbb23fa --- /dev/null +++ b/protocol/fuzz/fuzz_targets/receive_version.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Fuzz test for the receive_version function. +//! +//! This focused test fuzzes only the version packet decryption logic. + +#![no_main] +use bip324::{GarbageResult, Handshake, Initialized, Network, ReceivedKey, Role, VersionResult}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + // Skip if data is too small + if data.is_empty() { + return; + } + + // Set up a valid handshake in the ReceivedGarbage state + let initiator = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut initiator_key = vec![0u8; Handshake::::send_key_len(None)]; + let initiator = initiator.send_key(None, &mut initiator_key).unwrap(); + + let responder = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + let mut responder_key = vec![0u8; Handshake::::send_key_len(None)]; + let responder = responder.send_key(None, &mut responder_key).unwrap(); + + // Exchange keys + let initiator = initiator + .receive_key(responder_key[..64].try_into().unwrap()) + .unwrap(); + let responder = responder + .receive_key(initiator_key[..64].try_into().unwrap()) + .unwrap(); + + // Both send version packets + let mut initiator_version = vec![0u8; Handshake::::send_version_len(None)]; + let initiator = initiator + .send_version(&mut initiator_version, None) + .unwrap(); + + let mut responder_version = vec![0u8; Handshake::::send_version_len(None)]; + let _responder = responder + .send_version(&mut responder_version, None) + .unwrap(); + + // Process the responder's garbage terminator to get to ReceivedGarbage state + let (handshake, _consumed) = match initiator.receive_garbage(&responder_version) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => (handshake, consumed_bytes), + _ => panic!("Should find garbage terminator in valid version buffer"), + }; + + // Now fuzz the receive_version function with arbitrary packet data + let mut packet_data = data.to_vec(); + match handshake.receive_version(&mut packet_data) { + Ok(VersionResult::Complete { cipher: _ }) => { + // Successfully completed handshake + // This should only happen with valid encrypted version packet + } + Ok(VersionResult::Decoy(_)) => { + // Received a decoy packet + // This should happen when packet type indicates decoy + } + Err(_) => { + // Decryption or authentication failed + // This is the most common outcome with random data + } + } +}); diff --git a/protocol/src/handshake.rs b/protocol/src/handshake.rs index 51311e9..aa6f3f1 100644 --- a/protocol/src/handshake.rs +++ b/protocol/src/handshake.rs @@ -1,3 +1,14 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! # BIP-324 V2 Transport Protocol Handshake +//! +//! 1. **Key Exchange**: Both peers generate and exchange public keys using ElligatorSwift encoding. +//! 2. **Garbage**: Optional garbage bytes are sent to obscure traffic patterns. +//! 3. **Decoy Packets**: Optional decoy packets can be sent to further obscure traffic patterns. +//! 4. **Version Authentication**: Version packets are exchanged to negotiate the protocol version for the channel. +//! 5. **Session Establishment**: The secure communication channel is ready for message exchange. +//! ``` + use bitcoin::{ key::Secp256k1, secp256k1::{ @@ -10,13 +21,9 @@ use rand::Rng; use crate::{ CipherSession, Error, OutboundCipher, PacketType, Role, SessionKeyMaterial, - NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES, VERSION_CONTENT, + NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, VERSION_CONTENT, }; -/// Initial buffer for decoy and version packets in the handshake. -/// The buffer may have to be expanded if a party is sending large -/// decoy packets. -pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096; // Maximum number of garbage bytes before the terminator. const MAX_NUM_GARBAGE_BYTES: usize = 4095; @@ -27,94 +34,102 @@ pub struct EcdhPoint { elligator_swift: ElligatorSwift, } +/// **Initial state** of the handshake state machine which holds local secret materials. +pub struct Initialized { + point: EcdhPoint, +} + +/// **Second state** after sending the local public key. +pub struct SentKey<'a> { + point: EcdhPoint, + local_garbage: Option<&'a [u8]>, +} + +/// **Third state** after receiving the remote's public key and +/// generating the shared secret materials for the session. +pub struct ReceivedKey<'a> { + session_keys: SessionKeyMaterial, + local_garbage: Option<&'a [u8]>, +} + +/// **Fourth state** after sending the decoy and version packets. +pub struct SentVersion { + cipher: CipherSession, + remote_garbage_terminator: [u8; NUM_GARBAGE_TERMINTOR_BYTES], +} + +/// **Fifth state** after receiving the remote's garbage and garbage terminator. +pub struct ReceivedGarbage<'a> { + cipher: CipherSession, + remote_garbage: Option<&'a [u8]>, +} + +/// Success variants for reading remote garbage. +pub enum GarbageResult<'a> { + /// Successfully found garbage. + FoundGarbage { + handshake: Handshake>, + consumed_bytes: usize, + }, + /// No garbage terminator found, the input buffer needs to be extended. + NeedMoreData(Handshake), +} + +/// Success variants for receiving remote version. +pub enum VersionResult<'a> { + /// Successfully completed handshake. + Complete { cipher: CipherSession }, + /// Packet was a decoy, read the next to see if version. + Decoy(Handshake>), +} + /// Handshake state-machine to establish the secret material in the communication channel. /// -/// A handshake is first initialized to create local materials needed to setup communication -/// channel between an *initiator* and a *responder*. The next step is to call `complete_materials` -/// no matter if initiator or responder, however the responder should already have the -/// necessary materials from their peers request. `complete_materials` creates the response -/// packet to be sent from each peer and `authenticate_garbage_and_version` is then used -/// to verify the handshake. Finally, the `finalized` method is used to consumer the handshake -/// and return a cipher session for further communication on the channel. -pub struct Handshake<'a> { +/// The handshake progresses through multiple states, enforcing the protocol sequence at compile time. +/// +/// 1. `Initialized` - Initial state with local secret materials. +/// 2. `SentKey` - After sending local public key and optional garbage. +/// 3. `ReceivedKey` - After receiving remote's public key. +/// 4. `SentVersion` - After sending local garbage terminator and version packet. +/// 5. Complete - After receiving and authenticating remote's garbage, garbage terminator, decoy packets, and version packet. +pub struct Handshake { /// Bitcoin network both peers are operating on. network: Network, /// Local role in the handshake, initiator or responder. role: Role, - /// Local point for key exchange. - point: EcdhPoint, - /// Optional garbage bytes to send along in handshake. - garbage: Option<&'a [u8]>, - /// Peers expected garbage terminator. - remote_garbage_terminator: Option<[u8; NUM_GARBAGE_TERMINTOR_BYTES]>, - /// Cipher session output. - cipher_session: Option, - /// Decrypted length for next packet. Store state between authentication attempts to avoid resetting ciphers. - current_packet_length_bytes: Option, - /// Processesed buffer index. Store state between authentication attempts to avoid resetting ciphers. - current_buffer_index: usize, + /// State-specific data. + state: State, } -impl<'a> Handshake<'a> { - /// Initialize a V2 transport handshake with a peer. - /// - /// # Arguments - /// - /// * `network` - The bitcoin network which both peers operate on. - /// * `garbage` - Optional garbage to send in handshake. - /// * `buffer` - Packet buffer to send to peer which will include initial materials for handshake + garbage. - /// - /// # Returns - /// - /// An initialized handshake which must be finalized. - /// - /// # Errors - /// - /// Fails if their was an error generating the keypair. +// Methods available in all states. +impl Handshake { + /// Get the network this handshake is operating on. + pub fn network(&self) -> Network { + self.network + } + + /// Get the local role in the handshake. + pub fn role(&self) -> Role { + self.role + } +} + +impl Handshake { + /// Initialize a V2 transport handshake with a remote peer. #[cfg(feature = "std")] - pub fn new( - network: Network, - role: Role, - garbage: Option<&'a [u8]>, - buffer: &mut [u8], - ) -> Result { + pub fn new(network: Network, role: Role) -> Result { let mut rng = rand::thread_rng(); let curve = Secp256k1::signing_only(); - Self::new_with_rng(network, role, garbage, buffer, &mut rng, &curve) + Self::new_with_rng(network, role, &mut rng, &curve) } - /// Initialize a V2 transport handshake with a peer. - /// - /// # Arguments - /// - /// * `network` - The bitcoin network which both peers operate on. - /// * `garbage` - Optional garbage to send in handshake. - /// * `buffer` - Packet buffer to send to peer which will include initial materials for handshake + garbage. - /// * `rng` - Supplied Random Number Generator. - /// * `curve` - Supplied secp256k1 context. - /// - /// # Returns - /// - /// An initialized handshake which must be finalized. - /// - /// # Errors - /// - /// Fails if their was an error generating the keypair. + /// Initialize a V2 transport handshake with remote peer using supplied RNG and secp context. pub fn new_with_rng( network: Network, role: Role, - garbage: Option<&'a [u8]>, - buffer: &mut [u8], rng: &mut impl Rng, curve: &Secp256k1, ) -> Result { - if garbage - .as_ref() - .map_or(false, |g| g.len() > MAX_NUM_GARBAGE_BYTES) - { - return Err(Error::TooMuchGarbage); - }; - let mut secret_key_buffer = [0u8; 32]; rng.fill(&mut secret_key_buffer[..]); let sk = SecretKey::from_slice(&secret_key_buffer)?; @@ -126,54 +141,104 @@ impl<'a> Handshake<'a> { elligator_swift: es, }; - // Bounds check on the output buffer. - let required_bytes = garbage.map_or(NUM_ELLIGATOR_SWIFT_BYTES, |g| { - NUM_ELLIGATOR_SWIFT_BYTES + g.len() - }); - if buffer.len() < required_bytes { - return Err(Error::BufferTooSmall { required_bytes }); - }; + Ok(Handshake { + network, + role, + state: Initialized { point }, + }) + } + + /// Calculate how many bytes send_key() will write to buffer. + pub fn send_key_len(garbage: Option<&[u8]>) -> usize { + NUM_ELLIGATOR_SWIFT_BYTES + garbage.map(|g| g.len()).unwrap_or(0) + } + + /// Send local public key and optional garbage to initiate the handshake. + /// + /// # Parameters + /// + /// * `garbage` - Optional garbage bytes to append after the public key. Limited to 4095 bytes. + /// * `output_buffer` - Buffer to write the key and garbage. Must have sufficient capacity + /// as calculated by `send_key_len()`. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to receive remote peer's key material. + /// + /// # Errors + /// + /// * `TooMuchGarbage` - Garbage exceeds 4095 byte limit. + /// * `BufferTooSmall` - Output buffer insufficient for key + garbage. + pub fn send_key<'a>( + self, + garbage: Option<&'a [u8]>, + output_buffer: &mut [u8], + ) -> Result>, Error> { + // Validate garbage length + if let Some(g) = garbage { + if g.len() > MAX_NUM_GARBAGE_BYTES { + return Err(Error::TooMuchGarbage); + } + } + + let required_len = Self::send_key_len(garbage); + if output_buffer.len() < required_len { + return Err(Error::BufferTooSmall { + required_bytes: required_len, + }); + } + + // Write local ellswift public key. + output_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] + .copy_from_slice(&self.state.point.elligator_swift.to_array()); - buffer[0..64].copy_from_slice(&point.elligator_swift.to_array()); - if let Some(garbage) = garbage { - buffer[64..64 + garbage.len()].copy_from_slice(garbage); + // Write garbage if provided. + if let Some(g) = garbage { + output_buffer[NUM_ELLIGATOR_SWIFT_BYTES..NUM_ELLIGATOR_SWIFT_BYTES + g.len()] + .copy_from_slice(g); } Ok(Handshake { - network, - role, - point, - garbage, - remote_garbage_terminator: None, - cipher_session: None, - current_packet_length_bytes: None, - current_buffer_index: 0, + network: self.network, + role: self.role, + state: SentKey { + point: self.state.point, + local_garbage: garbage, + }, }) } +} - /// Complete the secret material handshake and send the version packet to peer. +impl<'a> Handshake> { + /// Process the remote peer's public key and derive shared session secrets. + /// + /// This is the **second state transition** in the handshake process, moving from + /// `SentKey` to `ReceivedKey` state. The method performs ECDH key exchange using + /// the received remote public key and generates all cryptographic material needed + /// for the secure session. /// - /// # Arguments + /// # Parameters /// - /// * `their_elliswift` - The key material of the remote peer. - /// * `response_buffer` - Buffer to write response for remote peer which includes the garbage terminator and version packet. - /// * `decoys` - Contents for decoy packets sent before version packet. + /// * `their_key` - The remote peer's 64-byte ElligatorSwift encoded public key. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to send version packet with derived session keys. /// /// # Errors /// - /// * `V1Protocol` - The remote is communicating on the V1 protocol instead of V2. Caller can fallback - /// to V1 if they want. - pub fn complete_materials( - &mut self, - their_elliswift: [u8; NUM_ELLIGATOR_SWIFT_BYTES], - response_buffer: &mut [u8], - decoys: Option<&[&[u8]]>, - ) -> Result<(), Error> { - // Short circuit if the remote is sending the V1 protocol network bytes. - // Gives the caller an opportunity to fallback to V1 if they choose. + /// * `V1Protocol` - Remote peer is using the legacy V1 protocol. + /// * `SecretGeneration` - Failed to derive session keys from ECDH. + pub fn receive_key( + self, + their_key: [u8; NUM_ELLIGATOR_SWIFT_BYTES], + ) -> Result>, Error> { + let their_ellswift = ElligatorSwift::from_array(their_key); + + // Check for V1 protocol magic bytes if self.network.magic() == bitcoin::p2p::Magic::from_bytes( - their_elliswift[..4] + their_key[..4] .try_into() .expect("64 byte array to have 4 byte prefix"), ) @@ -181,382 +246,657 @@ impl<'a> Handshake<'a> { return Err(Error::V1Protocol); } - let theirs = ElligatorSwift::from_array(their_elliswift); + // Compute session keys using ECDH + let (initiator_ellswift, responder_ellswift, secret, party) = match self.role { + Role::Initiator => ( + self.state.point.elligator_swift, + their_ellswift, + self.state.point.secret_key, + ElligatorSwiftParty::A, + ), + Role::Responder => ( + their_ellswift, + self.state.point.elligator_swift, + self.state.point.secret_key, + ElligatorSwiftParty::B, + ), + }; + + let session_keys = SessionKeyMaterial::from_ecdh( + initiator_ellswift, + responder_ellswift, + secret, + party, + self.network, + )?; + + Ok(Handshake { + network: self.network, + role: self.role, + state: ReceivedKey { + session_keys, + local_garbage: self.state.local_garbage, + }, + }) + } +} + +impl<'a> Handshake> { + /// Calculate how many bytes send_version() will write to buffer. + pub fn send_version_len(decoys: Option<&[&[u8]]>) -> usize { + let mut len = NUM_GARBAGE_TERMINTOR_BYTES + + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); + + // Add decoy packets length. + if let Some(decoys) = decoys { + for decoy in decoys { + len += OutboundCipher::encryption_buffer_len(decoy.len()); + } + } + + len + } - // Check if the buffer is large enough for the garbage terminator. - if response_buffer.len() < NUM_GARBAGE_TERMINTOR_BYTES { + /// Send garbage terminator, optional decoy packets, and version packet. + /// + /// This is the **third state transition** in the handshake process, moving from + /// `ReceivedKey` to `SentVersion` state. The method initiates encrypted communication + /// by sending the local garbage terminator followed by encrypted packets. + /// + /// # Parameters + /// + /// * `output_buffer` - Buffer to write terminator and encrypted packets. Must have + /// sufficient capacity as calculated by `send_version_len()`. + /// * `decoys` - Optional array of decoy packet contents to send before version packet + /// to help hide the shape of traffic. + /// + /// # Returns + /// + /// `Ok(Handshake)` - Ready to receive and authenticate remote peer's version. + /// + /// # Errors + /// + /// * `BufferTooSmall` - Output buffer insufficient for terminator + packets. + /// * `Decryption` - Cipher operation failed. + pub fn send_version( + self, + output_buffer: &mut [u8], + decoys: Option<&[&[u8]]>, + ) -> Result, Error> { + let required_len = Self::send_version_len(decoys); + if output_buffer.len() < required_len { return Err(Error::BufferTooSmall { - required_bytes: NUM_GARBAGE_TERMINTOR_BYTES, + required_bytes: required_len, }); } - // Line up appropriate materials based on role and some - // garbage terminator haggling. - let materials = match self.role { - Role::Initiator => { - let materials = SessionKeyMaterial::from_ecdh( - self.point.elligator_swift, - theirs, - self.point.secret_key, - ElligatorSwiftParty::A, - self.network, - )?; - response_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] - .copy_from_slice(&materials.initiator_garbage_terminator); - self.remote_garbage_terminator = Some(materials.responder_garbage_terminator); + let mut cipher = CipherSession::new(self.state.session_keys.clone(), self.role); - materials + // Write garbage terminator and determine remote terminator. + let remote_garbage_terminator = match self.role { + Role::Initiator => { + output_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] + .copy_from_slice(&self.state.session_keys.initiator_garbage_terminator); + self.state.session_keys.responder_garbage_terminator } Role::Responder => { - let materials = SessionKeyMaterial::from_ecdh( - theirs, - self.point.elligator_swift, - self.point.secret_key, - ElligatorSwiftParty::B, - self.network, - )?; - response_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] - .copy_from_slice(&materials.responder_garbage_terminator); - self.remote_garbage_terminator = Some(materials.initiator_garbage_terminator); - - materials + output_buffer[..NUM_GARBAGE_TERMINTOR_BYTES] + .copy_from_slice(&self.state.session_keys.responder_garbage_terminator); + self.state.session_keys.initiator_garbage_terminator } }; - let mut cipher_session = CipherSession::new(materials, self.role); - let mut start_index = NUM_GARBAGE_TERMINTOR_BYTES; + let mut bytes_written = NUM_GARBAGE_TERMINTOR_BYTES; + // Local garbage is authenticated in first packet no + // matter if it is a decoy or genuine. + let mut aad = self.state.local_garbage; - // Write any decoy packets and then the version packet. - // The first packet, no matter if decoy or genuinie version packet, needs - // to authenticate the garbage previously sent. if let Some(decoys) = decoys { - for (i, decoy) in decoys.iter().enumerate() { - let end_index = start_index + OutboundCipher::encryption_buffer_len(decoy.len()); - cipher_session.outbound().encrypt( + for decoy in decoys { + let packet_len = OutboundCipher::encryption_buffer_len(decoy.len()); + cipher.outbound().encrypt( decoy, - &mut response_buffer[start_index..end_index], + &mut output_buffer[bytes_written..bytes_written + packet_len], PacketType::Decoy, - if i == 0 { self.garbage } else { None }, + aad, )?; - - start_index = end_index; + aad = None; + bytes_written += packet_len; } } - cipher_session.outbound().encrypt( + let version_packet_len = OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); + cipher.outbound().encrypt( &VERSION_CONTENT, - &mut response_buffer[start_index - ..start_index + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len())], + &mut output_buffer[bytes_written..bytes_written + version_packet_len], PacketType::Genuine, - if decoys.is_none() { self.garbage } else { None }, + aad, )?; - self.cipher_session = Some(cipher_session); - - Ok(()) - } - - /// Authenticate the channel. - /// - /// Designed to be called multiple times until succesful in order to flush - /// garbage and decoy packets from channel. If a `BufferTooSmall ` is - /// returned, the buffer should be extended until `BufferTooSmall ` is - /// not returned. All other errors are fatal for the handshake and it should - /// be completely restarted. - /// - /// # Arguments - /// - /// * `buffer` - Should contain all garbage, the garbage terminator, any decoy packets, and finally the version packet received from peer. - /// * `packet_buffer` - Required memory allocation for decrypting decoy and version packets. - /// - /// # Error - /// - /// * `CiphertextTooSmall` - The buffer did not contain all required information and should be extended (e.g. read more off a socket) and authentication re-tried. - /// * `BufferTooSmall` - The supplied packet_buffer is not large enough for decrypting the decoy and version packets. - /// * `HandshakeOutOfOrder` - The handshake sequence is in a bad state and should be restarted. - /// * `MaxGarbageLength` - Buffer did not contain the garbage terminator, should not be retried. - pub fn authenticate_garbage_and_version( - &mut self, - buffer: &[u8], - packet_buffer: &mut [u8], - ) -> Result<(), Error> { - // Find the end of the garbage. - let (garbage, ciphertext) = self.split_garbage(buffer)?; - - // Flag to track if the version packet has been received to signal the end of the handshake. - let mut found_version_packet = false; - - // The first packet, even if it is a decoy packet, - // is used to authenticate the received garbage through - // the AAD. - if self.current_buffer_index == 0 { - found_version_packet = self.decrypt_packet(ciphertext, packet_buffer, Some(garbage))?; - } - - // If the first packet is a decoy, or if this is a follow up - // authentication attempt, the decoys need to be flushed and - // the version packet found. - // - // The version packet is essentially ignored in the current - // version of the protocol, but it does move the cipher - // states forward. It could be extended in the future. - while !found_version_packet { - found_version_packet = self.decrypt_packet(ciphertext, packet_buffer, None)?; - } - - Ok(()) + Ok(Handshake { + network: self.network, + role: self.role, + state: SentVersion { + cipher, + remote_garbage_terminator, + }, + }) } +} - /// Decrypt the next packet in the buffer while - /// book keeping relevant lengths and indexes. This allows - /// the buffer to be re-processed without throwing off - /// the state of the ciphers. +impl Handshake { + /// Process remote peer's garbage bytes and locate the garbage terminator. + /// + /// This is a critical step in the handshake process, transitioning from + /// `SentVersion` to `ReceivedGarbage` state. The method searches for the remote peer's + /// garbage terminator within the input buffer and separates garbage bytes from the + /// subsequent encrypted packet data. + /// + /// # Parameters + /// + /// * `input_buffer` - Buffer containing remote peer's garbage bytes followed by encrypted + /// packet data. The garbage terminator marks the boundary between these sections. /// /// # Returns /// - /// True if the decrypted packet is the version packet. - fn decrypt_packet( - &mut self, - ciphertext: &[u8], - packet_buffer: &mut [u8], - garbage: Option<&[u8]>, - ) -> Result { - let cipher_session = self - .cipher_session - .as_mut() - .ok_or(Error::HandshakeOutOfOrder)?; - - if self.current_packet_length_bytes.is_none() { - // Bounds check on the input buffer. - if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES { - return Err(Error::CiphertextTooSmall); + /// * `Ok(GarbageResult::FoundGarbage)` - Successfully located garbage terminator. + /// Contains the next handshake state and number of bytes consumed from input buffer. + /// * `Ok(GarbageResult::NeedMoreData)` - Garbage terminator not found in current buffer. + /// More data needed to locate the terminator boundary. + /// + /// # Errors + /// + /// * `NoGarbageTerminator` - Input exceeds maximum garbage size without finding terminator. + /// Indicates protocol violation or potential attack. + /// + /// # Example + /// + /// ```rust + /// use bip324::{Handshake, GarbageResult, SentVersion}; + /// # use bip324::{Role, Network}; + /// # fn example() -> Result<(), Box> { + /// # let mut handshake = Handshake::new(Network::Bitcoin, Role::Initiator)?; + /// # // ... complete handshake to SentVersion state ... + /// # let handshake: Handshake = todo!(); + /// + /// let mut network_buffer = Vec::new(); + /// + /// loop { + /// // Read more network data... + /// // network_buffer.extend_from_slice(&new_data); + /// + /// match handshake.receive_garbage(&network_buffer)? { + /// GarbageResult::FoundGarbage { handshake: next_state, consumed_bytes } => { + /// // Success! Process remaining data for version packets + /// let remaining_data = &network_buffer[consumed_bytes..]; + /// // Continue with next_state.receive_version()... + /// break; + /// } + /// GarbageResult::NeedMoreData(handshake) => { + /// // Continue accumulating network data + /// continue; + /// } + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn receive_garbage<'a>(self, input_buffer: &'a [u8]) -> Result, Error> { + match self.split_garbage(input_buffer) { + Ok((garbage, _ciphertext)) => { + let consumed_bytes = garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES; + let handshake = Handshake { + network: self.network, + role: self.role, + state: ReceivedGarbage { + cipher: self.state.cipher, + remote_garbage: if garbage.is_empty() { + None + } else { + Some(garbage) + }, + }, + }; + + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) } - let packet_length = cipher_session.inbound().decrypt_packet_len( - ciphertext[self.current_buffer_index..self.current_buffer_index + NUM_LENGTH_BYTES] - .try_into() - .expect("Buffer slice must be exactly 3 bytes long"), - ); - // Hang on to decrypted length incase follow up steps fail - // and another authentication attempt is required. Avoids - // throwing off the cipher state. - self.current_packet_length_bytes = Some(packet_length); + Err(Error::CiphertextTooSmall) => Ok(GarbageResult::NeedMoreData(self)), + Err(e) => Err(e), } + } - let packet_length = self - .current_packet_length_bytes - .ok_or(Error::HandshakeOutOfOrder)?; + /// Split buffer on garbage terminator. + fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { + let terminator = &self.state.remote_garbage_terminator; - // Bounds check on input buffer. - if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES + packet_length { - return Err(Error::CiphertextTooSmall); + if let Some(index) = buffer + .windows(terminator.len()) + .position(|window| window == terminator) + { + let (garbage, rest) = buffer.split_at(index); + let ciphertext = &rest[terminator.len()..]; + Ok((garbage, ciphertext)) + } else if buffer.len() >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES { + Err(Error::NoGarbageTerminator) + } else { + Err(Error::CiphertextTooSmall) } - let packet_type = cipher_session.inbound().decrypt( - &ciphertext[self.current_buffer_index + NUM_LENGTH_BYTES - ..self.current_buffer_index + NUM_LENGTH_BYTES + packet_length], - packet_buffer, - garbage, - )?; - - // Mark current decryption point in the buffer. - self.current_buffer_index = self.current_buffer_index + NUM_LENGTH_BYTES + packet_length; - self.current_packet_length_bytes = None; + } +} - Ok(matches!(packet_type, PacketType::Genuine)) +impl<'a> Handshake> { + /// Decrypt the packet length from the encrypted length bytes. + pub fn decrypt_packet_len(&mut self, length_bytes: [u8; 3]) -> Result { + Ok(self.state.cipher.inbound().decrypt_packet_len(length_bytes)) } - /// Complete the handshake and return the cipher session for further communication. + /// Decrypt and authenticate the next packet to complete the handshake. /// - /// # Error + /// This is the **final state transition** in the handshake process, completing the + /// BIP-324 protocol by processing the remote peer's version packet. The method performs + /// in-place decryption of encrypted packet data and determines whether the handshake + /// is complete or if additional decoy packets need processing. /// - /// * `HandshakeOutOfOrder` - The handshake sequence is in a bad state and should be restarted. - pub fn finalize(self) -> Result { - let cipher_session = self.cipher_session.ok_or(Error::HandshakeOutOfOrder)?; - Ok(cipher_session) - } - - /// Split off garbage in the given buffer on the remote garbage terminator. + /// # Unique Characteristics + /// + /// **Mutable Buffer Requirement**: Unlike other handshake methods, `receive_version()` + /// requires a mutable input buffer because it performs in-place decryption operations, + /// modifying ciphertext directly to produce plaintext for memory efficiency. + /// + /// # Parameters + /// + /// * `input_buffer` - **Mutable** buffer containing encrypted packet data (excluding + /// the 3-byte length prefix). The ciphertext will be overwritten with plaintext + /// during decryption. Buffer size must match the decrypted packet length. /// /// # Returns /// - /// A `Result` containing the garbage and the remaining ciphertext not including the terminator. + /// * `Ok(VersionResult::Complete { cipher })` - Handshake completed successfully. + /// The returned `CipherSession` is ready for secure message exchange. + /// * `Ok(VersionResult::Decoy(handshake))` - Packet was a decoy, continue processing + /// with the returned handshake state for the next packet. /// - /// # Error + /// # Errors /// - /// * `CiphertextTooSmall` - Buffer did not contain a garbage terminator. - /// * `MaxGarbageLength` - Buffer did not contain the garbage terminator and contains too much garbage, should not be retried. - fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> { - let garbage_term = self - .remote_garbage_terminator - .ok_or(Error::HandshakeOutOfOrder)?; - if let Some(index) = buffer - .windows(garbage_term.len()) - .position(|window| window == garbage_term) - { - Ok((&buffer[..index], &buffer[(index + garbage_term.len())..])) - } else if buffer.len() >= (MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES) { - Err(Error::NoGarbageTerminator) - } else { - // Terminator not found, the buffer needs more information. - Err(Error::CiphertextTooSmall) + /// * `Decryption` - Packet authentication failed or ciphertext is malformed. + /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet + /// + /// # Example + /// + /// ```rust + /// use bip324::{Handshake, VersionResult, ReceivedGarbage, NUM_LENGTH_BYTES}; + /// # use bip324::{Role, Network}; + /// # fn example() -> Result<(), Box> { + /// # let mut handshake = Handshake::new(Network::Bitcoin, Role::Initiator)?; + /// # // ... complete handshake to ReceivedGarbage state ... + /// # let mut handshake: Handshake = todo!(); + /// # let encrypted_data: &[u8] = todo!(); + /// + /// let mut remaining_data = encrypted_data; + /// + /// // Process packets until version packet found + /// loop { + /// // Read packet length (first 3 bytes) + /// let packet_len = handshake.decrypt_packet_len( + /// remaining_data[..NUM_LENGTH_BYTES].try_into()? + /// )?; + /// + /// // Extract packet data (excluding length prefix) + /// let mut packet = remaining_data[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + /// remaining_data = &remaining_data[NUM_LENGTH_BYTES + packet_len..]; + /// + /// // Process the packet + /// match handshake.receive_version(&mut packet)? { + /// VersionResult::Complete { cipher } => { + /// // Handshake complete! Ready for secure messaging + /// break; + /// } + /// VersionResult::Decoy(next_handshake) => { + /// // Decoy packet processed, continue with next packet + /// handshake = next_handshake; + /// } + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn receive_version(mut self, input_buffer: &mut [u8]) -> Result, Error> { + // Take the garbage on first call to ensure AAD is only used once. + let aad = self.state.remote_garbage.take(); + + let (packet_type, _) = self + .state + .cipher + .inbound() + .decrypt_in_place(input_buffer, aad)?; + + match packet_type { + PacketType::Genuine => Ok(VersionResult::Complete { + cipher: self.state.cipher, + }), + PacketType::Decoy => Ok(VersionResult::Decoy(self)), } } } #[cfg(all(test, feature = "std"))] mod tests { - use bitcoin::secp256k1::ellswift::{ElligatorSwift, ElligatorSwiftParty}; - use core::str::FromStr; - use hex::prelude::*; - use std::{string::ToString, vec}; + use std::vec; + + use crate::NUM_LENGTH_BYTES; use super::*; + // Test that the handshake completes successfully with garbage and decoy packets + // from both parties. This is a comprehensive integration test of the full protocol. #[test] - fn test_initial_message() { - let mut message = [0u8; 64]; - let handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut message).unwrap(); - let message = message.to_lower_hex_string(); - let es = handshake.point.elligator_swift.to_string(); - assert!(message.contains(&es)) - } + fn test_handshake() { + let initiator_garbage = vec![1u8, 2u8, 3u8]; + let responder_garbage = vec![4u8, 5u8]; + + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + + // Send initiator key + garbage. + let mut init_buffer = + vec![0u8; Handshake::::send_key_len(Some(&initiator_garbage))]; + let init_handshake = init_handshake + .send_key(Some(&initiator_garbage), &mut init_buffer) + .unwrap(); - #[test] - fn test_message_response() { - let mut message = [0u8; 64]; - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut message).unwrap(); - - let mut response_message = [0u8; 100]; - let mut response = Handshake::new( - Network::Bitcoin, - Role::Responder, - None, - &mut response_message, - ) - .unwrap(); - - response - .complete_materials(message, &mut response_message[64..], None) + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key + garbage. + let mut resp_buffer = + vec![0u8; Handshake::::send_key_len(Some(&responder_garbage))]; + let resp_handshake = resp_handshake + .send_key(Some(&responder_garbage), &mut resp_buffer) .unwrap(); - } - #[test] - fn test_shared_secret() { - // Test that SessionKeyMaterial::from_ecdh produces expected garbage terminators - let alice = - SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") + // Initiator receives responder's key. + let init_handshake = init_handshake + .receive_key(resp_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) + .unwrap(); + + // Responder receives initiator's key. + let resp_handshake = resp_handshake + .receive_key(init_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) + .unwrap(); + + // Create decoy packets for both sides. + let init_decoy1 = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let init_decoy2 = vec![0xCA, 0xFE, 0xBA, 0xBE, 0x00, 0x01]; + let init_decoys = vec![init_decoy1.as_slice(), init_decoy2.as_slice()]; + + let resp_decoy1 = vec![0xAB, 0xCD, 0xEF]; + let resp_decoys = vec![resp_decoy1.as_slice()]; + + // Initiator sends decoys and version. + let mut init_version_buffer = + vec![0u8; Handshake::>::send_version_len(Some(&init_decoys))]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, Some(&init_decoys)) + .unwrap(); + + // Responder sends decoys and version. + let mut resp_version_buffer = + vec![0u8; Handshake::>::send_version_len(Some(&resp_decoys))]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, Some(&resp_decoys)) + .unwrap(); + + // Initiator processes responder's response + let full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat(); + + // First, find the garbage terminator + let (mut init_handshake, consumed) = + match init_handshake.receive_garbage(&full_resp_message).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process the encrypted packets (1 decoy + 1 version) + let mut remaining = &full_resp_message[consumed..]; + + // First packet is a decoy + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; + + init_handshake = match init_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Decoy(handshake) => handshake, + VersionResult::Complete { .. } => panic!("First packet should be decoy"), + }; + + // Second packet is the version + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match init_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Second packet should be version"), + }; + + // Responder processes initiator's response + let full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat(); + + // First, find the garbage terminator + let (mut resp_handshake, consumed) = + match resp_handshake.receive_garbage(&full_init_message).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage terminator"), + }; + + // Process the encrypted packets (2 decoys + 1 version) + let mut remaining = &full_init_message[consumed..]; + + // First two packets are decoys + for i in 0..2 { + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) .unwrap(); - let elliswift_alice = ElligatorSwift::from_str("ec0adff257bbfe500c188c80b4fdd640f6b45a482bbc15fc7cef5931deff0aa186f6eb9bba7b85dc4dcc28b28722de1e3d9108b985e2967045668f66098e475b").unwrap(); - let elliswift_bob = ElligatorSwift::from_str("a4a94dfce69b4a2a0a099313d10f9f7e7d649d60501c9e1d274c300e0d89aafaffffffffffffffffffffffffffffffffffffffffffffffffffffffff8faf88d5").unwrap(); - let session_keys = SessionKeyMaterial::from_ecdh( - elliswift_alice, - elliswift_bob, - alice, - ElligatorSwiftParty::A, - Network::Bitcoin, - ) - .unwrap(); - // Just verify the garbage terminators which are the only public fields we need - assert_eq!( - "faef555dfcdb936425d84aba524758f3", - session_keys - .initiator_garbage_terminator - .to_lower_hex_string() - ); - assert_eq!( - "02cb8ff24307a6e27de3b4e7ea3fa65b", - session_keys - .responder_garbage_terminator - .to_lower_hex_string() - ); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; + + resp_handshake = match resp_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Decoy(handshake) => handshake, + VersionResult::Complete { .. } => panic!("Packet {} should be decoy", i), + }; + } + + // Third packet is the version + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match resp_handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Third packet should be version"), + }; } + // Test that send_key properly validates garbage length limits (max 4095 bytes) + // and buffer size requirements. #[test] - fn test_handshake_garbage_length_check() { - let mut rng = rand::thread_rng(); - let curve = Secp256k1::new(); - let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES]; - - // Test with valid garbage length. + fn test_handshake_send_key() { + // Test with valid garbage length let valid_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&valid_garbage), - &mut handshake_buffer, - &mut rng, - &curve, - ); + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + MAX_NUM_GARBAGE_BYTES]; + let result = handshake.send_key(Some(&valid_garbage), &mut buffer); assert!(result.is_ok()); - // Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES. + // Test with garbage length exceeding MAX_NUM_GARBAGE_BYTES let too_much_garbage = vec![0u8; MAX_NUM_GARBAGE_BYTES + 1]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&too_much_garbage), - &mut handshake_buffer, - &mut rng, - &curve, - ); + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(Some(&too_much_garbage), &mut buffer); assert!(matches!(result, Err(Error::TooMuchGarbage))); - // Test too small of buffer. + // Test too small of buffer let buffer_size = NUM_ELLIGATOR_SWIFT_BYTES + valid_garbage.len() - 1; let mut too_small_buffer = vec![0u8; buffer_size]; - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - Some(&valid_garbage), - &mut too_small_buffer, - &mut rng, - &curve, - ); - + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(Some(&valid_garbage), &mut too_small_buffer); assert!( matches!(result, Err(Error::BufferTooSmall { required_bytes }) if required_bytes == NUM_ELLIGATOR_SWIFT_BYTES + valid_garbage.len()), "Expected BufferTooSmall with correct size" ); - // Test with no garbage. - let result = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - None, - &mut handshake_buffer, - &mut rng, - &curve, - ); + // Test with no garbage + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let result = handshake.send_key(None, &mut buffer); assert!(result.is_ok()); } + // Test the NeedMoreData scenario where receive_garbage is called with partial data. + // The local peer doesn't know how much garbage the remote will send, so just needs + // to pull and do some buffer mangament. #[test] - fn test_handshake_no_garbage_terminator() { - let mut handshake_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; - let mut rng = rand::thread_rng(); - let curve = Secp256k1::signing_only(); + fn test_handshake_receive_garbage_buffer() { + let init_handshake = + Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let resp_handshake = + Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + let mut init_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let init_handshake = init_handshake.send_key(None, &mut init_buffer).unwrap(); + + let mut resp_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let resp_handshake = resp_handshake.send_key(None, &mut resp_buffer).unwrap(); + + let init_handshake = init_handshake + .receive_key(resp_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) + .unwrap(); + let resp_handshake = resp_handshake + .receive_key(init_buffer[..NUM_ELLIGATOR_SWIFT_BYTES].try_into().unwrap()) + .unwrap(); + + let mut init_version_buffer = + vec![0u8; Handshake::>::send_version_len(None)]; + let _init_handshake = init_handshake + .send_version(&mut init_version_buffer, None) + .unwrap(); + + let mut resp_version_buffer = + vec![0u8; Handshake::>::send_version_len(None)]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, None) + .unwrap(); + + // Test streaming scenario with receive_garbage + let partial_data_1 = &init_version_buffer[..1]; + let returned_handshake = match resp_handshake.receive_garbage(partial_data_1).unwrap() { + GarbageResult::NeedMoreData(handshake) => handshake, + GarbageResult::FoundGarbage { .. } => { + panic!("Should have needed more data with 1 byte") + } + }; + + // Feed a bit more data - still probably not enough. + let partial_data_2 = &init_version_buffer[..5]; + let returned_handshake = match returned_handshake.receive_garbage(partial_data_2).unwrap() { + GarbageResult::NeedMoreData(handshake) => handshake, + GarbageResult::FoundGarbage { .. } => { + panic!("Should have needed more data with 5 bytes") + } + }; + + // Now provide enough data to find the garbage terminator. + // Since there's no garbage, the terminator should be at the beginning. + let (mut handshake, consumed) = match returned_handshake + .receive_garbage(&init_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => { + panic!("Should have found garbage terminator with full data") + } + }; - let mut handshake = Handshake::new_with_rng( - Network::Bitcoin, - Role::Initiator, - None, - &mut handshake_buffer, - &mut rng, - &curve, - ) - .expect("Handshake creation should succeed"); + // Process the version packet + let remaining = &init_version_buffer[consumed..]; + let packet_len = handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match handshake.receive_version(&mut packet).unwrap() { + VersionResult::Complete { .. } => {} // Success! + VersionResult::Decoy(_) => panic!("Should be version packet"), + }; + } - // Skipping material creation and just placing a mock terminator. - handshake.remote_garbage_terminator = Some([0xFF; NUM_GARBAGE_TERMINTOR_BYTES]); + // Test split_garbage error conditions. + // + // 1. NoGarbageTerminator - when buffer exceeds max size without finding terminator + // 2. CiphertextTooSmall - when buffer is too short to possibly contain terminator + #[test] + fn test_handshake_split_garbage() { + // Create a handshake and bring it to the SentVersion state to test split_garbage + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let handshake = handshake.send_key(None, &mut buffer).unwrap(); + + // Create a fake peer key to receive + let fake_peer_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let handshake = handshake.receive_key(fake_peer_key).unwrap(); - // Test with a buffer that is too long. + // Send version to get to SentVersion state + let mut version_buffer = vec![0u8; 1024]; + let handshake = handshake.send_version(&mut version_buffer, None).unwrap(); + + // Test with a buffer that is too long (should fail to find terminator) let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES]; let result = handshake.split_garbage(&test_buffer); assert!(matches!(result, Err(Error::NoGarbageTerminator))); - // Test with a buffer that's just short of the required length. + // Test with a buffer that's just short of the required length let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1]; let result = handshake.split_garbage(&short_buffer); assert!(matches!(result, Err(Error::CiphertextTooSmall))); } + + // Test that receive_key detects V1 protocol when peer's key starts with network magic. + #[test] + fn test_v1_protocol_detection() { + // Test that receive_key properly detects V1 protocol magic bytes + let handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); + let mut buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + let handshake = handshake.send_key(None, &mut buffer).unwrap(); + + // Create a key that starts with Bitcoin mainnet magic bytes + let mut v1_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + v1_key[..4].copy_from_slice(&Network::Bitcoin.magic().to_bytes()); + + let result = handshake.receive_key(v1_key); + assert!(matches!(result, Err(Error::V1Protocol))); + + // Test with different networks + let handshake = Handshake::::new(Network::Testnet, Role::Responder).unwrap(); + let handshake = handshake.send_key(None, &mut buffer).unwrap(); + + let mut v1_testnet_key = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; + v1_testnet_key[..4].copy_from_slice(&Network::Testnet.magic().to_bytes()); + + let result = handshake.receive_key(v1_testnet_key); + assert!(matches!(result, Err(Error::V1Protocol))); + } } diff --git a/protocol/src/io.rs b/protocol/src/io.rs index 982f2b7..40d34ab 100644 --- a/protocol/src/io.rs +++ b/protocol/src/io.rs @@ -20,8 +20,9 @@ use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::{ + handshake::{self, GarbageResult, VersionResult}, Error, Handshake, InboundCipher, OutboundCipher, PacketType, Role, NUM_ELLIGATOR_SWIFT_BYTES, - NUM_GARBAGE_TERMINTOR_BYTES, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES, VERSION_CONTENT, + NUM_GARBAGE_TERMINTOR_BYTES, }; /// A decrypted BIP324 payload with its packet type. @@ -100,6 +101,23 @@ impl From for ProtocolError { } } +#[cfg(feature = "std")] +impl ProtocolError { + /// Create an EOF error that suggests retrying with V1 protocol. + /// + /// This is used when the remote peer closes the connection during handshake, + /// which often indicates they don't support the V2 protocol. + fn eof() -> Self { + ProtocolError::Io( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Remote peer closed connection during handshake", + ), + ProtocolFailureSuggestion::RetryV1, + ) + } +} + #[cfg(feature = "std")] impl std::error::Error for ProtocolError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { @@ -141,6 +159,8 @@ pub struct AsyncProtocol { impl AsyncProtocol { /// New protocol session which completes the initial handshake and returns a handler. /// + /// This function is *not* cancellation safe. + /// /// # Arguments /// /// * `network` - Network which both parties are operating on. @@ -171,95 +191,93 @@ impl AsyncProtocol { R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - let garbage_len = match garbage { - Some(slice) => slice.len(), - None => 0, - }; - let mut ellswift_buffer = vec![0u8; NUM_ELLIGATOR_SWIFT_BYTES + garbage_len]; - let mut handshake = Handshake::new(network, role, garbage, &mut ellswift_buffer)?; + let handshake = Handshake::::new(network, role)?; - // Send initial key to remote. - writer.write_all(&ellswift_buffer).await?; + // Send local public key and optional garbage. + let key_buffer_len = Handshake::::send_key_len(garbage); + let mut key_buffer = vec![0u8; key_buffer_len]; + let handshake = handshake.send_key(garbage, &mut key_buffer)?; + writer.write_all(&key_buffer).await?; writer.flush().await?; - // Read remote's initial key. - let mut remote_ellswift_buffer = [0u8; 64]; + // Read remote's public key. + let mut remote_ellswift_buffer = [0u8; NUM_ELLIGATOR_SWIFT_BYTES]; reader.read_exact(&mut remote_ellswift_buffer).await?; + let handshake = handshake.receive_key(remote_ellswift_buffer)?; - let num_version_packet_bytes = OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len()); - let num_decoy_packets_bytes: usize = match decoys { - Some(decoys) => decoys - .iter() - .map(|decoy| OutboundCipher::encryption_buffer_len(decoy.len())) - .sum(), - None => 0, - }; - - // Complete materials and send terminator to remote. - // Not exposing decoy packets yet. - let mut terminator_and_version_buffer = - vec![ - 0u8; - NUM_GARBAGE_TERMINTOR_BYTES + num_version_packet_bytes + num_decoy_packets_bytes - ]; - handshake.complete_materials( - remote_ellswift_buffer, - &mut terminator_and_version_buffer, - decoys, - )?; - writer.write_all(&terminator_and_version_buffer).await?; + // Send garbage terminator, decoys, and version. + let version_buffer_len = Handshake::::send_version_len(decoys); + let mut version_buffer = vec![0u8; version_buffer_len]; + let handshake = handshake.send_version(&mut version_buffer, decoys)?; + writer.write_all(&version_buffer).await?; writer.flush().await?; - // Receive and authenticate remote garbage and version. - // Keep pulling bytes from the buffer until the garbage is flushed. - let mut remote_garbage_and_version_buffer = - Vec::with_capacity(NUM_INITIAL_HANDSHAKE_BUFFER_BYTES); - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - - loop { - let mut temp_buffer = [0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - match reader.read(&mut temp_buffer).await { - // No data available right now, continue. - Ok(0) => { - continue; + // Receive and process garbage terminator + let mut garbage_buffer = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES]; + reader.read_exact(&mut garbage_buffer).await?; + + let mut handshake = handshake; + let (mut handshake, garbage_bytes) = loop { + match handshake.receive_garbage(&garbage_buffer) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => { + break (handshake, consumed_bytes); } - Ok(bytes_read) => { - remote_garbage_and_version_buffer.extend_from_slice(&temp_buffer[..bytes_read]); - - match handshake.authenticate_garbage_and_version( - &remote_garbage_and_version_buffer, - &mut packet_buffer, - ) { - Ok(()) => break, - // Not enough data, continue reading. - Err(Error::CiphertextTooSmall) => continue, - Err(Error::BufferTooSmall { required_bytes }) => { - packet_buffer.resize(required_bytes, 0); - continue; + Ok(GarbageResult::NeedMoreData(h)) => { + handshake = h; + // Use small chunks to avoid reading past garbage, decoys, and version. + let mut temp = vec![0u8; NUM_GARBAGE_TERMINTOR_BYTES]; + match reader.read(&mut temp).await { + Ok(0) => return Err(ProtocolError::eof()), + Ok(n) => { + garbage_buffer.extend_from_slice(&temp[..n]); } - Err(e) => return Err(ProtocolError::Internal(e)), + Err(e) => return Err(ProtocolError::from(e)), } } - Err(e) => match e.kind() { - // No data available or interrupted, retry. - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted => { - continue; - } - _ => return Err(ProtocolError::Io(e, ProtocolFailureSuggestion::Abort)), - }, + Err(e) => return Err(ProtocolError::Internal(e)), } - } - - let cipher_session = handshake.finalize()?; - let (inbound_cipher, outbound_cipher) = cipher_session.into_split(); + }; - Ok(Self { - reader: AsyncProtocolReader { - inbound_cipher, - state: DecryptState::init_reading_length(), - }, - writer: AsyncProtocolWriter { outbound_cipher }, - }) + // Process remaining bytes and read version packets. + let mut version_buffer = garbage_buffer[garbage_bytes..].to_vec(); + loop { + // Decrypt packet length. + if version_buffer.len() < 3 { + let old_len = version_buffer.len(); + version_buffer.resize(3, 0); + reader.read_exact(&mut version_buffer[old_len..]).await?; + } + let packet_len = + handshake.decrypt_packet_len(version_buffer[..3].try_into().unwrap())?; + version_buffer.drain(..3); + + // Process packet. + if version_buffer.len() < packet_len { + let old_len = version_buffer.len(); + version_buffer.resize(packet_len, 0); + reader.read_exact(&mut version_buffer[old_len..]).await?; + } + match handshake.receive_version(&mut version_buffer) { + Ok(VersionResult::Complete { cipher }) => { + let (inbound_cipher, outbound_cipher) = cipher.into_split(); + return Ok(Self { + reader: AsyncProtocolReader { + inbound_cipher, + state: DecryptState::init_reading_length(), + }, + writer: AsyncProtocolWriter { outbound_cipher }, + }); + } + Ok(VersionResult::Decoy(h)) => { + handshake = h; + version_buffer.drain(..packet_len); + } + Err(e) => return Err(ProtocolError::Internal(e)), + } + } } /// Read reference for packet reading operations. @@ -365,7 +383,8 @@ impl AsyncProtocolReader { self.inbound_cipher .decrypt(packet_bytes, &mut plaintext_buffer, None)?; self.state = DecryptState::init_reading_length(); - return Ok(Payload::new(plaintext_buffer, packet_type)); + // Skip the header byte (first byte) which contains the packet type + return Ok(Payload::new(plaintext_buffer[1..].to_vec(), packet_type)); } } } diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 70910d7..99224c7 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -39,7 +39,10 @@ use bitcoin_hashes::{hkdf, sha256, Hkdf}; pub use bitcoin::Network; -pub use handshake::{Handshake, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES}; +pub use handshake::{ + GarbageResult, Handshake, Initialized, ReceivedGarbage, ReceivedKey, SentKey, SentVersion, + VersionResult, +}; // Re-exports from io module (async I/O types for backwards compatibility) #[cfg(any(feature = "futures", feature = "tokio"))] pub use io::{ @@ -82,8 +85,6 @@ pub enum Error { /// The remote sent the maximum amount of garbage bytes without /// a garbage terminator in the handshake. NoGarbageTerminator, - /// A handshake step was not completed in the proper order. - HandshakeOutOfOrder, /// The remote peer is communicating on the V1 protocol. V1Protocol, /// Not able to generate secret material. @@ -108,7 +109,6 @@ impl fmt::Display for Error { Error::NoGarbageTerminator => { write!(f, "More than 4095 bytes of garbage recieved in the handshake before a terminator was sent.") } - Error::HandshakeOutOfOrder => write!(f, "Handshake flow out of sequence."), Error::SecretGeneration(e) => write!(f, "Cannot generate secrets: {e:?}."), Error::Decryption(e) => write!(f, "Decrytion error: {e:?}."), Error::V1Protocol => write!(f, "The remote peer is communicating on the V1 protocol."), @@ -127,7 +127,6 @@ impl std::error::Error for Error { Error::CiphertextTooSmall => None, Error::BufferTooSmall { .. } => None, Error::NoGarbageTerminator => None, - Error::HandshakeOutOfOrder => None, Error::V1Protocol => None, Error::SecretGeneration(e) => Some(e), Error::Decryption(e) => Some(e), @@ -312,7 +311,7 @@ pub struct InboundCipher { } impl InboundCipher { - /// Decrypt the length of the next inbound packet. + /// Decrypt the length of the packet's payload. /// /// Note that this returns the length of the remaining packet data /// to be read from the stream (header + contents + tag), not just the contents. @@ -340,11 +339,50 @@ impl InboundCipher { packet_len - NUM_TAG_BYTES } + /// Decrypt an inbound packet in-place. + /// + /// # Arguments + /// + /// * `ciphertext` - A mutable buffer containing the packet from the peer excluding + /// the first 3 length bytes. It should contain the header, contents, and authentication tag. + /// This buffer will be modified in-place during decryption. + /// * `aad` - Optional associated authenticated data. + /// + /// # Returns + /// + /// A `Result` containing: + /// * `Ok((PacketType, &[u8]))`: A tuple of the packet type and a slice pointing to the + /// decrypted plaintext within the input buffer. The first byte of the slice is the + /// header byte containing protocol flags. + /// * `Err(Error)`: An error that occurred during decryption. + /// + /// # Errors + /// + /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet. + /// * Decryption errors for any failures such as a tag mismatch. + pub fn decrypt_in_place<'a>( + &mut self, + ciphertext: &'a mut [u8], + aad: Option<&[u8]>, + ) -> Result<(PacketType, &'a [u8]), Error> { + let auth = aad.unwrap_or_default(); + // Check minimum size of ciphertext. + if ciphertext.len() < NUM_TAG_BYTES { + return Err(Error::CiphertextTooSmall); + } + let (msg, tag) = ciphertext.split_at_mut(ciphertext.len() - NUM_TAG_BYTES); + + self.packet_cipher + .decrypt(auth, msg, tag.try_into().expect("16 byte tag"))?; + + Ok((PacketType::from_byte(&msg[0]), msg)) + } + /// Decrypt an inbound packet. /// /// # Arguments /// - /// * `packet_payload` - The packet from the peer excluding the first 3 length bytes. It should contain + /// * `ciphertext` - The packet from the peer excluding the first 3 length bytes. It should contain /// the header, contents, and authentication tag. /// * `plaintext_buffer` - Mutable buffer to write plaintext. Note that the first byte is the header byte /// containing protocol flags. @@ -363,16 +401,16 @@ impl InboundCipher { /// * Decryption errors for any failures such as a tag mismatch. pub fn decrypt( &mut self, - packet_payload: &[u8], + ciphertext: &[u8], plaintext_buffer: &mut [u8], aad: Option<&[u8]>, ) -> Result { let auth = aad.unwrap_or_default(); // Check minimum size of ciphertext. - if packet_payload.len() < NUM_TAG_BYTES { + if ciphertext.len() < NUM_TAG_BYTES { return Err(Error::CiphertextTooSmall); } - let (msg, tag) = packet_payload.split_at(packet_payload.len() - NUM_TAG_BYTES); + let (msg, tag) = ciphertext.split_at(ciphertext.len() - NUM_TAG_BYTES); // Check that the contents buffer is large enough. if plaintext_buffer.len() < msg.len() { return Err(Error::BufferTooSmall { @@ -399,7 +437,7 @@ pub struct OutboundCipher { impl OutboundCipher { /// Calculate the required encryption buffer length for given plaintext length. - pub fn encryption_buffer_len(plaintext_len: usize) -> usize { + pub const fn encryption_buffer_len(plaintext_len: usize) -> usize { plaintext_len + NUM_PACKET_OVERHEAD_BYTES } @@ -408,7 +446,7 @@ impl OutboundCipher { /// # Arguments /// /// * `plaintext` - Plaintext contents to be encrypted. - /// * `packet_buffer` - Buffer to write packet bytes to which must have enough capacity + /// * `ciphertext_buffer` - Buffer to write packet bytes to which must have enough capacity /// as calculated by `encryption_buffer_len(plaintext.len())`. /// * `packet_type` - Is this a genuine packet or a decoy. /// * `aad` - Optional associated authenticated data. @@ -420,12 +458,12 @@ impl OutboundCipher { pub fn encrypt( &mut self, plaintext: &[u8], - packet_buffer: &mut [u8], + ciphertext_buffer: &mut [u8], packet_type: PacketType, aad: Option<&[u8]>, ) -> Result<(), Error> { // Validate buffer capacity. - if packet_buffer.len() < Self::encryption_buffer_len(plaintext.len()) { + if ciphertext_buffer.len() < Self::encryption_buffer_len(plaintext.len()) { return Err(Error::BufferTooSmall { required_bytes: Self::encryption_buffer_len(plaintext.len()), }); @@ -437,14 +475,15 @@ impl OutboundCipher { let plaintext_end_index = plaintext_start_index + plaintext_length; // Set header byte. - packet_buffer[header_index] = packet_type.to_byte(); - packet_buffer[plaintext_start_index..plaintext_end_index].copy_from_slice(plaintext); + ciphertext_buffer[header_index] = packet_type.to_byte(); + ciphertext_buffer[plaintext_start_index..plaintext_end_index].copy_from_slice(plaintext); // Encrypt header byte and plaintext in place and produce authentication tag. let auth = aad.unwrap_or_default(); - let tag = self - .packet_cipher - .encrypt(auth, &mut packet_buffer[header_index..plaintext_end_index]); + let tag = self.packet_cipher.encrypt( + auth, + &mut ciphertext_buffer[header_index..plaintext_end_index], + ); // Encrypt plaintext length. let mut content_len = [0u8; 3]; @@ -452,8 +491,8 @@ impl OutboundCipher { self.length_cipher.crypt(&mut content_len); // Copy over encrypted length and the tag to the final packet (plaintext already encrypted). - packet_buffer[0..NUM_LENGTH_BYTES].copy_from_slice(&content_len); - packet_buffer[plaintext_end_index..(plaintext_end_index + NUM_TAG_BYTES)] + ciphertext_buffer[0..NUM_LENGTH_BYTES].copy_from_slice(&content_len); + ciphertext_buffer[plaintext_end_index..(plaintext_end_index + NUM_TAG_BYTES)] .copy_from_slice(&tag); Ok(()) @@ -472,7 +511,7 @@ pub struct CipherSession { } impl CipherSession { - fn new(materials: SessionKeyMaterial, role: Role) -> Self { + pub(crate) fn new(materials: SessionKeyMaterial, role: Role) -> Self { match role { Role::Initiator => { let initiator_length_cipher = FSChaCha20::new(materials.initiator_length_key); @@ -538,7 +577,6 @@ impl CipherSession { #[cfg(all(test, feature = "std"))] mod tests { - use crate::handshake::NUM_INITIAL_HANDSHAKE_BUFFER_BYTES; use super::*; use bitcoin::secp256k1::ellswift::{ElligatorSwift, ElligatorSwiftParty}; @@ -610,6 +648,62 @@ mod tests { assert_eq!(message, plaintext_buffer[1..].to_vec()); // Skip header byte } + #[test] + fn test_decrypt_in_place() { + let alice = + SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") + .unwrap(); + let elliswift_alice = ElligatorSwift::from_str("ec0adff257bbfe500c188c80b4fdd640f6b45a482bbc15fc7cef5931deff0aa186f6eb9bba7b85dc4dcc28b28722de1e3d9108b985e2967045668f66098e475b").unwrap(); + let elliswift_bob = ElligatorSwift::from_str("a4a94dfce69b4a2a0a099313d10f9f7e7d649d60501c9e1d274c300e0d89aafaffffffffffffffffffffffffffffffffffffffffffffffffffffffff8faf88d5").unwrap(); + let session_keys = SessionKeyMaterial::from_ecdh( + elliswift_alice, + elliswift_bob, + alice, + ElligatorSwiftParty::A, + Network::Bitcoin, + ) + .unwrap(); + let mut alice_cipher = CipherSession::new(session_keys.clone(), Role::Initiator); + let mut bob_cipher = CipherSession::new(session_keys, Role::Responder); + + // Test with a genuine packet + let message = b"Test in-place decryption".to_vec(); + let mut enc_packet = vec![0u8; OutboundCipher::encryption_buffer_len(message.len())]; + alice_cipher + .outbound() + .encrypt(&message, &mut enc_packet, PacketType::Genuine, None) + .unwrap(); + + // Decrypt in-place + let mut ciphertext = enc_packet[NUM_LENGTH_BYTES..].to_vec(); + let (packet_type, plaintext) = bob_cipher + .inbound() + .decrypt_in_place(&mut ciphertext, None) + .unwrap(); + + assert_eq!(PacketType::Genuine, packet_type); + assert_eq!(message, plaintext[1..].to_vec()); // Skip header byte + + // Test with a decoy packet and AAD + let message2 = b"Decoy with AAD".to_vec(); + let aad = b"additional authenticated data"; + let mut enc_packet2 = vec![0u8; OutboundCipher::encryption_buffer_len(message2.len())]; + bob_cipher + .outbound() + .encrypt(&message2, &mut enc_packet2, PacketType::Decoy, Some(aad)) + .unwrap(); + + // Decrypt in-place with AAD + let mut ciphertext2 = enc_packet2[NUM_LENGTH_BYTES..].to_vec(); + let (packet_type2, plaintext2) = alice_cipher + .inbound() + .decrypt_in_place(&mut ciphertext2, Some(aad)) + .unwrap(); + + assert_eq!(PacketType::Decoy, packet_type2); + assert_eq!(message2, plaintext2[1..].to_vec()); // Skip header byte + } + #[test] fn test_fuzz_packets() { let mut rng = rand::thread_rng(); @@ -678,7 +772,7 @@ mod tests { } #[test] - fn test_authenticated_garbage() { + fn test_additional_authenticated_data() { let mut rng = rand::thread_rng(); let alice = SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") @@ -724,230 +818,6 @@ mod tests { .unwrap(); } - #[test] - fn test_handshake_with_garbage_and_decoys() { - // Define the garbage and decoys that the initiator is sending to the responder. - let initiator_garbage = vec![1u8, 2u8, 3u8]; - let initiator_decoys: Vec<&[u8]> = vec![&[6u8, 7u8], &[8u8, 0u8]]; - let num_initiator_decoys_bytes = initiator_decoys - .iter() - .map(|slice| slice.len()) - .sum::() - + NUM_PACKET_OVERHEAD_BYTES * initiator_decoys.len(); - let num_initiator_version_bytes = VERSION_CONTENT.len() + NUM_PACKET_OVERHEAD_BYTES; - // Buffer for initiator to write to and responder to read from. - let mut initiator_buffer = vec![ - 0u8; - NUM_ELLIGATOR_SWIFT_BYTES - + initiator_garbage.len() - + NUM_GARBAGE_TERMINTOR_BYTES - + num_initiator_decoys_bytes - + num_initiator_version_bytes - ]; - - // Define the garbage and decoys that the responder is sending to the initiator. - let responder_garbage = vec![4u8, 5u8]; - let responder_decoys: Vec<&[u8]> = vec![&[10u8, 11u8], &[12u8], &[13u8, 14u8, 15u8]]; - let num_responder_decoys_bytes = responder_decoys - .iter() - .map(|slice| slice.len()) - .sum::() - + NUM_PACKET_OVERHEAD_BYTES * responder_decoys.len(); - let num_responder_version_bytes = VERSION_CONTENT.len() + NUM_PACKET_OVERHEAD_BYTES; - // Buffer for responder to write to and initiator to read from. - let mut responder_buffer = vec![ - 0u8; - NUM_ELLIGATOR_SWIFT_BYTES - + responder_garbage.len() - + NUM_GARBAGE_TERMINTOR_BYTES - + num_responder_decoys_bytes - + num_responder_version_bytes - ]; - - // The initiator's handshake writes its 64 byte elligator swift key and garbage to their buffer to send to the responder. - let mut initiator_handshake = Handshake::new( - Network::Bitcoin, - Role::Initiator, - Some(&initiator_garbage), - &mut initiator_buffer[..NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()], - ) - .unwrap(); - - // The responder also writes its 64 byte elligator swift key and garbage to their buffer to send to the initiator. - let mut responder_handshake = Handshake::new( - Network::Bitcoin, - Role::Responder, - Some(&responder_garbage), - &mut responder_buffer[..NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()], - ) - .unwrap(); - - // The responder has received the initiator's initial material so can complete the secrets. - // With the secrets calculated, the responder can send along the garbage terminator, decoys, and version packet. - responder_handshake - .complete_materials( - initiator_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] - .try_into() - .unwrap(), - &mut responder_buffer[NUM_ELLIGATOR_SWIFT_BYTES + responder_garbage.len()..], - Some(&responder_decoys), - ) - .unwrap(); - - // Once the initiator receives the responder's response it can also complete the secrets. - // The initiator then needs to send along their recently calculated garbage terminator, decoys, and the version packet. - initiator_handshake - .complete_materials( - responder_buffer[..NUM_ELLIGATOR_SWIFT_BYTES] - .try_into() - .unwrap(), - &mut initiator_buffer[NUM_ELLIGATOR_SWIFT_BYTES + initiator_garbage.len()..], - Some(&initiator_decoys), - ) - .unwrap(); - - // The initiator verifies the second half of the responders message which - // includes the garbage, garbage terminator, decoys, and version packet. - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - initiator_handshake - .authenticate_garbage_and_version( - &responder_buffer[NUM_ELLIGATOR_SWIFT_BYTES..], - &mut packet_buffer, - ) - .unwrap(); - - // The responder verifies the second message from the initiator which - // includes the garbage, garbage terminator, decoys, and version packet. - responder_handshake - .authenticate_garbage_and_version( - &initiator_buffer[NUM_ELLIGATOR_SWIFT_BYTES..], - &mut packet_buffer, - ) - .unwrap(); - - let mut alice = initiator_handshake.finalize().unwrap(); - let mut bob = responder_handshake.finalize().unwrap(); - - let message = b"Hello world".to_vec(); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut encrypted_message_to_alice = vec![0u8; packet_len]; - bob.outbound() - .encrypt( - &message, - &mut encrypted_message_to_alice, - PacketType::Genuine, - None, - ) - .unwrap(); - - let bob_message_len = alice.inbound().decrypt_packet_len( - encrypted_message_to_alice[..NUM_LENGTH_BYTES] - .try_into() - .unwrap(), - ); - let mut dec = vec![0u8; InboundCipher::decryption_buffer_len(bob_message_len)]; - alice - .inbound() - .decrypt( - &encrypted_message_to_alice[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + bob_message_len], - &mut dec, - None, - ) - .unwrap(); - assert_eq!(message, dec[1..].to_vec()); // Skip header byte - - let message = b"g!".to_vec(); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut encrypted_message_to_bob = vec![0u8; packet_len]; - alice - .outbound() - .encrypt( - &message, - &mut encrypted_message_to_bob, - PacketType::Genuine, - None, - ) - .unwrap(); - - let alice_message_len = bob.inbound().decrypt_packet_len( - encrypted_message_to_bob[..NUM_LENGTH_BYTES] - .try_into() - .unwrap(), - ); - let mut dec = vec![0u8; InboundCipher::decryption_buffer_len(alice_message_len)]; - bob.inbound() - .decrypt( - &encrypted_message_to_bob[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + alice_message_len], - &mut dec, - None, - ) - .unwrap(); - assert_eq!(message, dec[1..].to_vec()); // Skip header byte - } - - #[test] - fn test_partial_decodings() { - let mut rng = rand::thread_rng(); - - let mut init_message = vec![0u8; 64]; - let mut init_handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut init_message).unwrap(); - - let mut resp_message = vec![0u8; 100]; - let mut resp_handshake = - Handshake::new(Network::Bitcoin, Role::Responder, None, &mut resp_message).unwrap(); - - resp_handshake - .complete_materials( - init_message.try_into().unwrap(), - &mut resp_message[64..], - None, - ) - .unwrap(); - let mut init_finalize_message = vec![0u8; 36]; - init_handshake - .complete_materials( - resp_message[0..64].try_into().unwrap(), - &mut init_finalize_message, - None, - ) - .unwrap(); - - let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES]; - init_handshake - .authenticate_garbage_and_version(&resp_message[64..], &mut packet_buffer) - .unwrap(); - resp_handshake - .authenticate_garbage_and_version(&init_finalize_message, &mut packet_buffer) - .unwrap(); - - let mut alice = init_handshake.finalize().unwrap(); - let mut bob = resp_handshake.finalize().unwrap(); - - let mut message_to_bob = Vec::new(); - let message = gen_garbage(420, &mut rng); - let packet_len = OutboundCipher::encryption_buffer_len(message.len()); - let mut enc_packet = vec![0u8; packet_len]; - alice - .outbound() - .encrypt(&message, &mut enc_packet, PacketType::Genuine, None) - .unwrap(); - message_to_bob.extend(&enc_packet); - - let alice_message_len = bob - .inbound() - .decrypt_packet_len(message_to_bob[..3].try_into().unwrap()); - let mut contents = vec![0u8; InboundCipher::decryption_buffer_len(alice_message_len)]; - bob.inbound() - .decrypt( - &message_to_bob[3..3 + alice_message_len], - &mut contents, - None, - ) - .unwrap(); - assert_eq!(message, contents[1..].to_vec()); // Skip header byte - } - // The rest are sourced from [the BIP324 test vectors](https://github.com/bitcoin/bips/blob/master/bip-0324/packet_encoding_test_vectors.csv). #[test] diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 86f9352..0fdd7d0 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -5,43 +5,93 @@ const PORT: u16 = 18444; #[test] #[cfg(feature = "std")] fn hello_world_happy_path() { - use bip324::{Handshake, PacketType, Role}; + use bip324::{ + GarbageResult, Handshake, Initialized, PacketType, ReceivedKey, Role, VersionResult, + NUM_LENGTH_BYTES, + }; use bitcoin::Network; - let mut init_message = vec![0u8; 64]; - let mut init_handshake = - Handshake::new(Network::Bitcoin, Role::Initiator, None, &mut init_message).unwrap(); + // Create initiator handshake + let init_handshake = Handshake::::new(Network::Bitcoin, Role::Initiator).unwrap(); - let mut resp_message = vec![0u8; 100]; - let mut resp_handshake = - Handshake::new(Network::Bitcoin, Role::Responder, None, &mut resp_message).unwrap(); + // Send initiator key + let mut init_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let init_handshake = init_handshake.send_key(None, &mut init_key_buffer).unwrap(); - resp_handshake - .complete_materials( - init_message.try_into().unwrap(), - &mut resp_message[64..], - None, - ) + // Create responder handshake + let resp_handshake = Handshake::::new(Network::Bitcoin, Role::Responder).unwrap(); + + // Send responder key + let mut resp_key_buffer = vec![0u8; Handshake::::send_key_len(None)]; + let resp_handshake = resp_handshake.send_key(None, &mut resp_key_buffer).unwrap(); + + // Initiator receives responder's key + let init_handshake = init_handshake + .receive_key(resp_key_buffer[..64].try_into().unwrap()) .unwrap(); - let mut init_finalize_message = vec![0u8; 36]; - init_handshake - .complete_materials( - resp_message[0..64].try_into().unwrap(), - &mut init_finalize_message, - None, - ) + + // Responder receives initiator's key + let resp_handshake = resp_handshake + .receive_key(init_key_buffer[..64].try_into().unwrap()) + .unwrap(); + + // Initiator sends version + let mut init_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let init_handshake = init_handshake + .send_version(&mut init_version_buffer, None) .unwrap(); - let mut packet_buffer = vec![0u8; 4096]; - init_handshake - .authenticate_garbage_and_version(&resp_message[64..], &mut packet_buffer) + // Responder sends version + let mut resp_version_buffer = vec![0u8; Handshake::::send_version_len(None)]; + let resp_handshake = resp_handshake + .send_version(&mut resp_version_buffer, None) .unwrap(); - resp_handshake - .authenticate_garbage_and_version(&init_finalize_message, &mut packet_buffer) + + // Initiator receives responder's garbage and version + let (mut init_handshake, consumed) = match init_handshake + .receive_garbage(&resp_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Process the version packet properly + let remaining = &resp_version_buffer[consumed..]; + let packet_len = init_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) .unwrap(); + let mut version_packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + let mut alice = match init_handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), + }; - let mut alice = init_handshake.finalize().unwrap(); - let mut bob = resp_handshake.finalize().unwrap(); + // Responder receives initiator's garbage and version + let (mut resp_handshake, consumed) = match resp_handshake + .receive_garbage(&init_version_buffer) + .unwrap() + { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => (handshake, consumed_bytes), + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Process the version packet properly + let remaining = &init_version_buffer[consumed..]; + let packet_len = resp_handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + let mut version_packet = remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + let mut bob = match resp_handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => cipher, + VersionResult::Decoy(_) => panic!("Should have completed"), + }; // Alice and Bob can freely exchange encrypted messages using the packet handler returned by each handshake. let message = b"Hello world".to_vec(); @@ -56,15 +106,17 @@ fn hello_world_happy_path() { ) .unwrap(); - let alice_message_len = alice - .inbound() - .decrypt_packet_len(encrypted_message_to_alice[..3].try_into().unwrap()); + let alice_message_len = alice.inbound().decrypt_packet_len( + encrypted_message_to_alice[..NUM_LENGTH_BYTES] + .try_into() + .unwrap(), + ); let mut decrypted_message = vec![0u8; bip324::InboundCipher::decryption_buffer_len(alice_message_len)]; alice .inbound() .decrypt( - &encrypted_message_to_alice[3..], + &encrypted_message_to_alice[NUM_LENGTH_BYTES..], &mut decrypted_message, None, ) @@ -84,13 +136,19 @@ fn hello_world_happy_path() { ) .unwrap(); - let bob_message_len = bob - .inbound() - .decrypt_packet_len(encrypted_message_to_bob[..3].try_into().unwrap()); + let bob_message_len = bob.inbound().decrypt_packet_len( + encrypted_message_to_bob[..NUM_LENGTH_BYTES] + .try_into() + .unwrap(), + ); let mut decrypted_message = vec![0u8; bip324::InboundCipher::decryption_buffer_len(bob_message_len)]; bob.inbound() - .decrypt(&encrypted_message_to_bob[3..], &mut decrypted_message, None) + .decrypt( + &encrypted_message_to_bob[NUM_LENGTH_BYTES..], + &mut decrypted_message, + None, + ) .unwrap(); assert_eq!(message, decrypted_message[1..].to_vec()); // Skip header byte } @@ -106,48 +164,95 @@ fn regtest_handshake() { use bip324::{ serde::{deserialize, serialize, NetworkMessage}, - Handshake, PacketType, + GarbageResult, Handshake, Initialized, PacketType, ReceivedKey, VersionResult, + NUM_LENGTH_BYTES, }; use bitcoin::p2p::{message_network::VersionMessage, Address, ServiceFlags}; let bitcoind = regtest_process(TransportVersion::V2); let mut stream = TcpStream::connect(bitcoind.params.p2p_socket.unwrap()).unwrap(); - let mut public_key = [0u8; 64]; - let mut handshake = Handshake::new( - bip324::Network::Regtest, - bip324::Role::Initiator, - None, - &mut public_key, - ) - .unwrap(); + + // Initialize handshake + let handshake = + Handshake::::new(bip324::Network::Regtest, bip324::Role::Initiator).unwrap(); + + // Send our public key + let mut public_key = vec![0u8; Handshake::::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut public_key).unwrap(); println!("Writing public key to the remote node"); stream.write_all(&public_key).unwrap(); stream.flush().unwrap(); + + // Read remote public key let mut remote_public_key = [0u8; 64]; println!("Reading the remote node public key"); stream.read_exact(&mut remote_public_key).unwrap(); - let mut local_garbage_terminator_message = [0u8; 36]; + + // Process remote key + let handshake = handshake.receive_key(remote_public_key).unwrap(); + + // Send garbage terminator and version + let mut local_garbage_terminator_message = + vec![0u8; Handshake::::send_version_len(None)]; println!("Sending our garbage terminator"); - handshake - .complete_materials( - remote_public_key, - &mut local_garbage_terminator_message, - None, - ) + let handshake = handshake + .send_version(&mut local_garbage_terminator_message, None) .unwrap(); stream.write_all(&local_garbage_terminator_message).unwrap(); stream.flush().unwrap(); + + // Read and authenticate remote response let mut max_response = [0; 4096]; println!("Reading the response buffer"); let size = stream.read(&mut max_response).unwrap(); - let response = &mut max_response[..size]; + let response = &max_response[..size]; println!("Authenticating the handshake"); - let mut packet_buffer = vec![0u8; 4096]; - handshake - .authenticate_garbage_and_version(response, &mut packet_buffer) - .unwrap(); - println!("Finalizing the handshake"); - let cipher_session = handshake.finalize().unwrap(); + + // First receive garbage + let (mut handshake, consumed) = match handshake.receive_garbage(response).unwrap() { + GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + } => { + println!("Found garbage terminator after {consumed_bytes} bytes"); + (handshake, consumed_bytes) + } + GarbageResult::NeedMoreData(_) => panic!("Should have found garbage"), + }; + + // Then receive version - properly handle packet length and potential decoys. + let mut remaining = &response[consumed..]; + let cipher_session = loop { + // Check if we have enough data for packet length + if remaining.len() < NUM_LENGTH_BYTES { + panic!("Not enough data for packet length"); + } + + let packet_len = handshake + .decrypt_packet_len(remaining[..NUM_LENGTH_BYTES].try_into().unwrap()) + .unwrap(); + + if remaining.len() < NUM_LENGTH_BYTES + packet_len { + panic!("Not enough data for full packet"); + } + + let mut version_packet = + remaining[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len].to_vec(); + + match handshake.receive_version(&mut version_packet).unwrap() { + VersionResult::Complete { cipher } => { + println!("Finalizing the handshake"); + break cipher; + } + VersionResult::Decoy(h) => { + println!("Received decoy packet, continuing..."); + handshake = h; + remaining = &remaining[NUM_LENGTH_BYTES + packet_len..]; + continue; + } + } + }; + let (mut decrypter, mut encrypter) = cipher_session.into_split(); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -175,7 +280,7 @@ fn regtest_handshake() { println!("Serializing and writing version message"); stream.write_all(&packet).unwrap(); println!("Reading the response length buffer"); - let mut response_len = [0; 3]; + let mut response_len = [0; NUM_LENGTH_BYTES]; stream.read_exact(&mut response_len).unwrap(); let message_len = decrypter.decrypt_packet_len(response_len); let mut response_message = vec![0; message_len]; @@ -198,18 +303,15 @@ fn regtest_handshake_v1_only() { net::TcpStream, }; - use bip324::Handshake; + use bip324::{Handshake, Initialized}; let bitcoind = regtest_process(TransportVersion::V1); let mut stream = TcpStream::connect(bitcoind.params.p2p_socket.unwrap()).unwrap(); - let mut public_key = [0u8; 64]; - let _ = Handshake::new( - bip324::Network::Regtest, - bip324::Role::Initiator, - None, - &mut public_key, - ) - .unwrap(); + + let handshake = + Handshake::::new(bip324::Network::Regtest, bip324::Role::Initiator).unwrap(); + let mut public_key = vec![0u8; Handshake::::send_key_len(None)]; + let _handshake = handshake.send_key(None, &mut public_key).unwrap(); println!("Writing public key to the remote node"); stream.write_all(&public_key).unwrap(); stream.flush().unwrap();