@@ -20,14 +20,17 @@ use bitcoin::{
2020use rand:: Rng ;
2121
2222use crate :: {
23- CipherSession , Error , InboundCipher , OutboundCipher , PacketType , Role , SessionKeyMaterial ,
23+ CipherSession , Error , OutboundCipher , PacketType , Role , SessionKeyMaterial ,
2424 NUM_ELLIGATOR_SWIFT_BYTES , NUM_GARBAGE_TERMINTOR_BYTES , NUM_LENGTH_BYTES , VERSION_CONTENT ,
2525} ;
2626
27- /// Initial buffer for decoy and version packets in the handshake.
28- /// The buffer may have to be expanded if a party is sending large
29- /// decoy packets.
30- pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES : usize = 4096 ;
27+ /// Initial buffer length hint to receive garbage, the garbage terminator,
28+ /// and the version packet from the remote peer. This assumes no decoy packets
29+ /// are sent (which is the default in Bitcoin Core) and no garbage bytes.
30+ /// Calculated as: garbage_terminator (16 bytes) + encrypted_version_packet
31+ /// where the version packet content is currently empty (0 bytes).
32+ pub const NUM_INITIAL_BUFFER_BYTES_HINT : usize =
33+ NUM_GARBAGE_TERMINTOR_BYTES + OutboundCipher :: encryption_buffer_len ( VERSION_CONTENT . len ( ) ) ;
3134// Maximum number of garbage bytes before the terminator.
3235const MAX_NUM_GARBAGE_BYTES : usize = 4095 ;
3336
@@ -60,6 +63,8 @@ pub struct SentVersion {
6063 cipher : CipherSession ,
6164 remote_garbage_terminator : [ u8 ; NUM_GARBAGE_TERMINTOR_BYTES ] ,
6265 bytes_written : usize ,
66+ ciphertext_index : usize ,
67+ remote_garbage_authenticated : bool ,
6368}
6469
6570/// Success variants for receive_version.
@@ -147,7 +152,22 @@ impl<'a> Handshake<'a, Initialized> {
147152 NUM_ELLIGATOR_SWIFT_BYTES + garbage. map ( |g| g. len ( ) ) . unwrap_or ( 0 )
148153 }
149154
150- /// Send local public key and optional garbage.
155+ /// Send local public key and optional garbage to initiate the handshake.
156+ ///
157+ /// # Parameters
158+ ///
159+ /// * `garbage` - Optional garbage bytes to append after the public key. Limited to 4095 bytes.
160+ /// * `output_buffer` - Buffer to write the key and garbage. Must have sufficient capacity
161+ /// as calculated by `send_key_len()`.
162+ ///
163+ /// # Returns
164+ ///
165+ /// `Ok(Handshake<SentKey>)` - Ready to receive remote peer's key material.
166+ ///
167+ /// # Errors
168+ ///
169+ /// * `TooMuchGarbage` - Garbage exceeds 4095 byte limit.
170+ /// * `BufferTooSmall` - Output buffer insufficient for key + garbage.
151171 pub fn send_key (
152172 mut self ,
153173 garbage : Option < & ' a [ u8 ] > ,
@@ -200,7 +220,25 @@ impl<'a> Handshake<'a, SentKey> {
200220 self . state . bytes_written
201221 }
202222
203- /// Process received key material.
223+ /// Process the remote peer's public key and derive shared session secrets.
224+ ///
225+ /// This is the **second state transition** in the handshake process, moving from
226+ /// `SentKey` to `ReceivedKey` state. The method performs ECDH key exchange using
227+ /// the received remote public key and generates all cryptographic material needed
228+ /// for the secure session.
229+ ///
230+ /// # Parameters
231+ ///
232+ /// * `their_key` - The remote peer's 64-byte ElligatorSwift encoded public key.
233+ ///
234+ /// # Returns
235+ ///
236+ /// `Ok(Handshake<ReceivedKey>)` - Ready to send version packet with derived session keys.
237+ ///
238+ /// # Errors
239+ ///
240+ /// * `V1Protocol` - Remote peer is using the legacy V1 protocol.
241+ /// * `SecretGeneration` - Failed to derive session keys from ECDH.
204242 pub fn receive_key (
205243 self ,
206244 their_key : [ u8 ; NUM_ELLIGATOR_SWIFT_BYTES ] ,
@@ -268,7 +306,27 @@ impl<'a> Handshake<'a, ReceivedKey> {
268306 len
269307 }
270308
271- /// Send garbage terminator, optional decoys, and version packet.
309+ /// Send garbage terminator, optional decoy packets, and version packet.
310+ ///
311+ /// This is the **third state transition** in the handshake process, moving from
312+ /// `ReceivedKey` to `SentVersion` state. The method initiates encrypted communication
313+ /// by sending the local garbage terminator followed by encrypted packets.
314+ ///
315+ /// # Parameters
316+ ///
317+ /// * `output_buffer` - Buffer to write terminator and encrypted packets. Must have
318+ /// sufficient capacity as calculated by `send_version_len()`.
319+ /// * `decoys` - Optional array of decoy packet contents to send before version packet
320+ /// to help hide the shape of traffic.
321+ ///
322+ /// # Returns
323+ ///
324+ /// `Ok(Handshake<SentVersion>)` - Ready to receive and authenticate remote peer's version.
325+ ///
326+ /// # Errors
327+ ///
328+ /// * `BufferTooSmall` - Output buffer insufficient for terminator + packets.
329+ /// * `Decryption` - Cipher operation failed.
272330 pub fn send_version (
273331 self ,
274332 output_buffer : & mut [ u8 ] ,
@@ -334,6 +392,8 @@ impl<'a> Handshake<'a, ReceivedKey> {
334392 cipher,
335393 remote_garbage_terminator,
336394 bytes_written,
395+ ciphertext_index : 0 ,
396+ remote_garbage_authenticated : false ,
337397 } ,
338398 } )
339399 }
@@ -346,11 +406,30 @@ impl<'a> Handshake<'a, SentVersion> {
346406 self . state . bytes_written
347407 }
348408
349- /// Authenticate garbage and version packet.
409+ /// Authenticate remote peer's garbage, decoy packets, and version packet.
410+ ///
411+ /// This method is unique in the handshake process as it requires a **mutable** input buffer
412+ /// to perform in-place decryption operations. The buffer contains everything after the 64
413+ /// byte public key received from the remote peer: optional garbage bytes, garbage terminator,
414+ /// and encrypted packets (decoys and final version packet).
415+ ///
416+ /// The input buffer is mutable in the case because the caller generally doesn't care
417+ /// about the decoy and version packets, and definitely doesn't want to deal with
418+ /// allocating memory for them.
419+ ///
420+ /// # Parameters
421+ ///
422+ /// * `input_buffer` - **Mutable** buffer containing garbage + terminator + encrypted packets.
423+ /// The buffer will be modified during in-place decryption operations.
424+ ///
425+ /// # Returns
426+ ///
427+ /// * `Complete { cipher, bytes_consumed }` - Handshake succeeded, secure session established.
428+ /// * `NeedMoreData(handshake)` - Insufficient data, retry by extending the buffer.
429+ /// ```
350430 pub fn receive_version (
351431 mut self ,
352- input_buffer : & [ u8 ] ,
353- output_buffer : & mut [ u8 ] ,
432+ input_buffer : & mut [ u8 ] ,
354433 ) -> Result < HandshakeAuthentication < ' a > , Error > {
355434 let ( garbage, ciphertext) = match self . split_garbage ( input_buffer) {
356435 Ok ( split) => split,
@@ -360,22 +439,25 @@ impl<'a> Handshake<'a, SentVersion> {
360439 Err ( e) => return Err ( e) ,
361440 } ;
362441
363- let mut aad = if garbage. is_empty ( ) {
442+ let mut aad = if garbage. is_empty ( ) || self . state . remote_garbage_authenticated {
364443 None
365444 } else {
366445 Some ( garbage)
367446 } ;
368- let mut ciphertext_index = 0 ;
369- let mut found_version = false ;
370447
371448 // First packet authenticates remote garbage.
372449 // Continue through decoys until we find version packet.
373- while !found_version {
374- match self . decrypt_packet ( & ciphertext[ ciphertext_index..] , output_buffer , aad) {
450+ loop {
451+ match self . decrypt_packet ( & mut ciphertext[ self . state . ciphertext_index ..] , aad) {
375452 Ok ( ( packet_type, bytes_consumed) ) => {
376- aad = None ;
377- ciphertext_index += bytes_consumed;
378- found_version = matches ! ( packet_type, PacketType :: Genuine ) ;
453+ if aad. is_some ( ) {
454+ aad = None ;
455+ self . state . remote_garbage_authenticated = true ;
456+ }
457+ self . state . ciphertext_index += bytes_consumed;
458+ if matches ! ( packet_type, PacketType :: Genuine ) {
459+ break ;
460+ }
379461 }
380462 Err ( Error :: CiphertextTooSmall ) => {
381463 return Ok ( HandshakeAuthentication :: NeedMoreData ( self ) )
@@ -384,8 +466,9 @@ impl<'a> Handshake<'a, SentVersion> {
384466 }
385467 }
386468
387- // Calculate total bytes consumed (garbage + terminator + packets)
388- let bytes_consumed = garbage. len ( ) + NUM_GARBAGE_TERMINTOR_BYTES + ciphertext_index;
469+ // Calculate total bytes consumed.
470+ let bytes_consumed =
471+ garbage. len ( ) + NUM_GARBAGE_TERMINTOR_BYTES + self . state . ciphertext_index ;
389472
390473 Ok ( HandshakeAuthentication :: Complete {
391474 cipher : self . state . cipher ,
@@ -394,25 +477,27 @@ impl<'a> Handshake<'a, SentVersion> {
394477 }
395478
396479 /// Split buffer on garbage terminator.
397- fn split_garbage < ' b > ( & self , buffer : & ' b [ u8 ] ) -> Result < ( & ' b [ u8 ] , & ' b [ u8 ] ) , Error > {
480+ fn split_garbage < ' b > ( & self , buffer : & ' b mut [ u8 ] ) -> Result < ( & ' b [ u8 ] , & ' b mut [ u8 ] ) , Error > {
398481 let terminator = & self . state . remote_garbage_terminator ;
399482
400483 if let Some ( index) = buffer
401484 . windows ( terminator. len ( ) )
402485 . position ( |window| window == terminator)
403486 {
404- Ok ( ( & buffer[ ..index] , & buffer[ index + terminator. len ( ) ..] ) )
487+ let ( garbage, rest) = buffer. split_at_mut ( index) ;
488+ let ciphertext = & mut rest[ terminator. len ( ) ..] ;
489+ Ok ( ( garbage, ciphertext) )
405490 } else if buffer. len ( ) >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES {
406491 Err ( Error :: NoGarbageTerminator )
407492 } else {
408493 Err ( Error :: CiphertextTooSmall )
409494 }
410495 }
411496
497+ /// Decrypts in place, returns the packet type and number of bytes consumed from the ciphertext.
412498 fn decrypt_packet (
413499 & mut self ,
414- ciphertext : & [ u8 ] ,
415- output_buffer : & mut [ u8 ] ,
500+ ciphertext : & mut [ u8 ] ,
416501 aad : Option < & [ u8 ] > ,
417502 ) -> Result < ( PacketType , usize ) , Error > {
418503 if ciphertext. len ( ) < NUM_LENGTH_BYTES {
@@ -429,17 +514,8 @@ impl<'a> Handshake<'a, SentVersion> {
429514 return Err ( Error :: CiphertextTooSmall ) ;
430515 }
431516
432- // Check output buffer is large enough.
433- let plaintext_len = InboundCipher :: decryption_buffer_len ( packet_len) ;
434- if output_buffer. len ( ) < plaintext_len {
435- return Err ( Error :: BufferTooSmall {
436- required_bytes : plaintext_len,
437- } ) ;
438- }
439-
440- let packet_type = self . state . cipher . inbound ( ) . decrypt (
441- & ciphertext[ NUM_LENGTH_BYTES ..NUM_LENGTH_BYTES + packet_len] ,
442- & mut output_buffer[ ..plaintext_len] ,
517+ let ( packet_type, _) = self . state . cipher . inbound ( ) . decrypt_in_place (
518+ & mut ciphertext[ NUM_LENGTH_BYTES ..NUM_LENGTH_BYTES + packet_len] ,
443519 aad,
444520 ) ?;
445521
@@ -512,12 +588,10 @@ mod tests {
512588 . send_version ( & mut resp_version_buffer, Some ( & resp_decoys) )
513589 . unwrap ( ) ;
514590
515- let mut packet_buffer = vec ! [ 0u8 ; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ] ;
516-
517591 // Initiator receives responder's garbage, decoys, and version.
518- let full_resp_message = [ & responder_garbage[ ..] , & resp_version_buffer[ ..] ] . concat ( ) ;
592+ let mut full_resp_message = [ & responder_garbage[ ..] , & resp_version_buffer[ ..] ] . concat ( ) ;
519593 match init_handshake
520- . receive_version ( & full_resp_message , & mut packet_buffer )
594+ . receive_version ( & mut full_resp_message )
521595 . unwrap ( )
522596 {
523597 HandshakeAuthentication :: Complete { bytes_consumed, .. } => {
@@ -527,9 +601,9 @@ mod tests {
527601 }
528602
529603 // Responder receives initiator's garbage, decoys, and version.
530- let full_init_message = [ & initiator_garbage[ ..] , & init_version_buffer[ ..] ] . concat ( ) ;
604+ let mut full_init_message = [ & initiator_garbage[ ..] , & init_version_buffer[ ..] ] . concat ( ) ;
531605 match resp_handshake
532- . receive_version ( & full_init_message , & mut packet_buffer )
606+ . receive_version ( & mut full_init_message )
533607 . unwrap ( )
534608 {
535609 HandshakeAuthentication :: Complete { bytes_consumed, .. } => {
@@ -601,13 +675,10 @@ mod tests {
601675 . send_version ( & mut resp_version_buffer, None )
602676 . unwrap ( ) ;
603677
604- let mut packet_buffer = vec ! [ 0u8 ; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ] ;
605-
606678 // Feed data in very small chunks to trigger NeedMoreData.
607679 let partial_data_1 = & init_version_buffer[ ..1 ] ;
608- let returned_handshake = match resp_handshake
609- . receive_version ( partial_data_1, & mut packet_buffer)
610- . unwrap ( )
680+ let mut partial_data_1 = partial_data_1. to_vec ( ) ;
681+ let returned_handshake = match resp_handshake. receive_version ( & mut partial_data_1) . unwrap ( )
611682 {
612683 HandshakeAuthentication :: NeedMoreData ( handshake) => handshake,
613684 HandshakeAuthentication :: Complete { .. } => {
@@ -617,8 +688,9 @@ mod tests {
617688
618689 // Feed a bit more data - still probably not enough.
619690 let partial_data_2 = & init_version_buffer[ ..5 ] ;
691+ let mut partial_data_2 = partial_data_2. to_vec ( ) ;
620692 let returned_handshake = match returned_handshake
621- . receive_version ( partial_data_2 , & mut packet_buffer )
693+ . receive_version ( & mut partial_data_2 )
622694 . unwrap ( )
623695 {
624696 HandshakeAuthentication :: NeedMoreData ( handshake) => handshake,
@@ -628,10 +700,8 @@ mod tests {
628700 } ;
629701
630702 // Now provide the complete data.
631- match returned_handshake
632- . receive_version ( & init_version_buffer, & mut packet_buffer)
633- . unwrap ( )
634- {
703+ let mut full_data = init_version_buffer. clone ( ) ;
704+ match returned_handshake. receive_version ( & mut full_data) . unwrap ( ) {
635705 HandshakeAuthentication :: Complete { bytes_consumed, .. } => {
636706 assert_eq ! ( bytes_consumed, init_version_buffer. len( ) ) ;
637707 }
@@ -657,13 +727,13 @@ mod tests {
657727 let handshake = handshake. send_version ( & mut version_buffer, None ) . unwrap ( ) ;
658728
659729 // Test with a buffer that is too long (should fail to find terminator)
660- let test_buffer = vec ! [ 0 ; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES ] ;
661- let result = handshake. split_garbage ( & test_buffer) ;
730+ let mut test_buffer = vec ! [ 0 ; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES ] ;
731+ let result = handshake. split_garbage ( & mut test_buffer) ;
662732 assert ! ( matches!( result, Err ( Error :: NoGarbageTerminator ) ) ) ;
663733
664734 // Test with a buffer that's just short of the required length
665- let short_buffer = vec ! [ 0 ; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1 ] ;
666- let result = handshake. split_garbage ( & short_buffer) ;
735+ let mut short_buffer = vec ! [ 0 ; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1 ] ;
736+ let result = handshake. split_garbage ( & mut short_buffer) ;
667737 assert ! ( matches!( result, Err ( Error :: CiphertextTooSmall ) ) ) ;
668738 }
669739}
0 commit comments