Skip to content

Commit 386828f

Browse files
committed
feat!: decrypt in place for handshake
Update the receive_version to take a mutable input buffer instead of an immutable input and a mutable output buffer. During the handshake, callers don't care about decoy packets or version packet. And they really don't want to deal with the variable sized output buffer.
1 parent 6fafc1f commit 386828f

6 files changed

Lines changed: 170 additions & 115 deletions

File tree

protocol/benches/cipher_session.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ extern crate test;
66

77
use bip324::{
88
CipherSession, Handshake, HandshakeAuthentication, InboundCipher, Initialized, Network,
9-
OutboundCipher, PacketType, ReceivedKey, Role, NUM_INITIAL_HANDSHAKE_BUFFER_BYTES,
9+
OutboundCipher, PacketType, ReceivedKey, Role,
1010
};
1111
use test::{black_box, Bencher};
1212

@@ -45,11 +45,9 @@ fn create_cipher_session_pair() -> (CipherSession, CipherSession) {
4545
.send_version(&mut bob_version_buffer, None)
4646
.unwrap();
4747

48-
let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES];
49-
5048
// Alice receives Bob's version.
5149
let alice = match alice_handshake
52-
.receive_version(&bob_version_buffer, &mut packet_buffer)
50+
.receive_version(&mut bob_version_buffer)
5351
.unwrap()
5452
{
5553
HandshakeAuthentication::Complete { cipher, .. } => cipher,
@@ -58,7 +56,7 @@ fn create_cipher_session_pair() -> (CipherSession, CipherSession) {
5856

5957
// Bob receives Alice's version.
6058
let bob = match bob_handshake
61-
.receive_version(&alice_version_buffer, &mut packet_buffer)
59+
.receive_version(&mut alice_version_buffer)
6260
.unwrap()
6361
{
6462
HandshakeAuthentication::Complete { cipher, .. } => cipher,

protocol/fuzz/fuzz_targets/handshake.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
//! * The implementation should handle all inputs gracefully.
88
99
#![no_main]
10-
use bip324::{
11-
Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role,
12-
NUM_INITIAL_HANDSHAKE_BUFFER_BYTES,
13-
};
10+
use bip324::{Handshake, HandshakeAuthentication, Initialized, Network, ReceivedKey, Role};
1411
use libfuzzer_sys::fuzz_target;
1512

1613
fuzz_target!(|data: &[u8]| {
@@ -41,9 +38,8 @@ fuzz_target!(|data: &[u8]| {
4138
let handshake = handshake.send_version(&mut version_buffer, None).unwrap();
4239

4340
// Try to receive and authenticate the fuzzed garbage and version data.
44-
let garbage_and_version = Vec::from(&data[64..]);
45-
let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES];
46-
match handshake.receive_version(&garbage_and_version, &mut packet_buffer) {
41+
let mut garbage_and_version = Vec::from(&data[64..]);
42+
match handshake.receive_version(&mut garbage_and_version) {
4743
Ok(HandshakeAuthentication::Complete { .. }) => {
4844
// Handshake completed successfully.
4945
// This should only happen with some very lucky random bytes.

protocol/src/handshake.rs

Lines changed: 126 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@ use bitcoin::{
2020
use rand::Rng;
2121

2222
use crate::{
23-
CipherSession, Error, InboundCipher, OutboundCipher, PacketType, Role, SessionKeyMaterial,
23+
CipherSession, Error, OutboundCipher, PacketType, Role, SessionKeyMaterial,
2424
NUM_ELLIGATOR_SWIFT_BYTES, NUM_GARBAGE_TERMINTOR_BYTES, NUM_LENGTH_BYTES, VERSION_CONTENT,
2525
};
2626

27-
/// Initial buffer for decoy and version packets in the handshake.
28-
/// The buffer may have to be expanded if a party is sending large
29-
/// decoy packets.
30-
pub const NUM_INITIAL_HANDSHAKE_BUFFER_BYTES: usize = 4096;
27+
/// Initial buffer length hint to receive garbage, the garbage terminator,
28+
/// and the version packet from the remote peer. This assumes no decoy packets
29+
/// are sent (which is the default in Bitcoin Core) and no garbage bytes.
30+
/// Calculated as: garbage_terminator (16 bytes) + encrypted_version_packet
31+
/// where the version packet content is currently empty (0 bytes).
32+
pub const NUM_INITIAL_BUFFER_BYTES_HINT: usize =
33+
NUM_GARBAGE_TERMINTOR_BYTES + OutboundCipher::encryption_buffer_len(VERSION_CONTENT.len());
3134
// Maximum number of garbage bytes before the terminator.
3235
const MAX_NUM_GARBAGE_BYTES: usize = 4095;
3336

@@ -60,6 +63,8 @@ pub struct SentVersion {
6063
cipher: CipherSession,
6164
remote_garbage_terminator: [u8; NUM_GARBAGE_TERMINTOR_BYTES],
6265
bytes_written: usize,
66+
ciphertext_index: usize,
67+
remote_garbage_authenticated: bool,
6368
}
6469

6570
/// Success variants for receive_version.
@@ -147,7 +152,22 @@ impl<'a> Handshake<'a, Initialized> {
147152
NUM_ELLIGATOR_SWIFT_BYTES + garbage.map(|g| g.len()).unwrap_or(0)
148153
}
149154

150-
/// Send local public key and optional garbage.
155+
/// Send local public key and optional garbage to initiate the handshake.
156+
///
157+
/// # Parameters
158+
///
159+
/// * `garbage` - Optional garbage bytes to append after the public key. Limited to 4095 bytes.
160+
/// * `output_buffer` - Buffer to write the key and garbage. Must have sufficient capacity
161+
/// as calculated by `send_key_len()`.
162+
///
163+
/// # Returns
164+
///
165+
/// `Ok(Handshake<SentKey>)` - Ready to receive remote peer's key material.
166+
///
167+
/// # Errors
168+
///
169+
/// * `TooMuchGarbage` - Garbage exceeds 4095 byte limit.
170+
/// * `BufferTooSmall` - Output buffer insufficient for key + garbage.
151171
pub fn send_key(
152172
mut self,
153173
garbage: Option<&'a [u8]>,
@@ -200,7 +220,25 @@ impl<'a> Handshake<'a, SentKey> {
200220
self.state.bytes_written
201221
}
202222

203-
/// Process received key material.
223+
/// Process the remote peer's public key and derive shared session secrets.
224+
///
225+
/// This is the **second state transition** in the handshake process, moving from
226+
/// `SentKey` to `ReceivedKey` state. The method performs ECDH key exchange using
227+
/// the received remote public key and generates all cryptographic material needed
228+
/// for the secure session.
229+
///
230+
/// # Parameters
231+
///
232+
/// * `their_key` - The remote peer's 64-byte ElligatorSwift encoded public key.
233+
///
234+
/// # Returns
235+
///
236+
/// `Ok(Handshake<ReceivedKey>)` - Ready to send version packet with derived session keys.
237+
///
238+
/// # Errors
239+
///
240+
/// * `V1Protocol` - Remote peer is using the legacy V1 protocol.
241+
/// * `SecretGeneration` - Failed to derive session keys from ECDH.
204242
pub fn receive_key(
205243
self,
206244
their_key: [u8; NUM_ELLIGATOR_SWIFT_BYTES],
@@ -268,7 +306,27 @@ impl<'a> Handshake<'a, ReceivedKey> {
268306
len
269307
}
270308

271-
/// Send garbage terminator, optional decoys, and version packet.
309+
/// Send garbage terminator, optional decoy packets, and version packet.
310+
///
311+
/// This is the **third state transition** in the handshake process, moving from
312+
/// `ReceivedKey` to `SentVersion` state. The method initiates encrypted communication
313+
/// by sending the local garbage terminator followed by encrypted packets.
314+
///
315+
/// # Parameters
316+
///
317+
/// * `output_buffer` - Buffer to write terminator and encrypted packets. Must have
318+
/// sufficient capacity as calculated by `send_version_len()`.
319+
/// * `decoys` - Optional array of decoy packet contents to send before version packet
320+
/// to help hide the shape of traffic.
321+
///
322+
/// # Returns
323+
///
324+
/// `Ok(Handshake<SentVersion>)` - Ready to receive and authenticate remote peer's version.
325+
///
326+
/// # Errors
327+
///
328+
/// * `BufferTooSmall` - Output buffer insufficient for terminator + packets.
329+
/// * `Decryption` - Cipher operation failed.
272330
pub fn send_version(
273331
self,
274332
output_buffer: &mut [u8],
@@ -334,6 +392,8 @@ impl<'a> Handshake<'a, ReceivedKey> {
334392
cipher,
335393
remote_garbage_terminator,
336394
bytes_written,
395+
ciphertext_index: 0,
396+
remote_garbage_authenticated: false,
337397
},
338398
})
339399
}
@@ -346,11 +406,30 @@ impl<'a> Handshake<'a, SentVersion> {
346406
self.state.bytes_written
347407
}
348408

349-
/// Authenticate garbage and version packet.
409+
/// Authenticate remote peer's garbage, decoy packets, and version packet.
410+
///
411+
/// This method is unique in the handshake process as it requires a **mutable** input buffer
412+
/// to perform in-place decryption operations. The buffer contains everything after the 64
413+
/// byte public key received from the remote peer: optional garbage bytes, garbage terminator,
414+
/// and encrypted packets (decoys and final version packet).
415+
///
416+
/// The input buffer is mutable in the case because the caller generally doesn't care
417+
/// about the decoy and version packets, and definitely doesn't want to deal with
418+
/// allocating memory for them.
419+
///
420+
/// # Parameters
421+
///
422+
/// * `input_buffer` - **Mutable** buffer containing garbage + terminator + encrypted packets.
423+
/// The buffer will be modified during in-place decryption operations.
424+
///
425+
/// # Returns
426+
///
427+
/// * `Complete { cipher, bytes_consumed }` - Handshake succeeded, secure session established.
428+
/// * `NeedMoreData(handshake)` - Insufficient data, retry by extending the buffer.
429+
/// ```
350430
pub fn receive_version(
351431
mut self,
352-
input_buffer: &[u8],
353-
output_buffer: &mut [u8],
432+
input_buffer: &mut [u8],
354433
) -> Result<HandshakeAuthentication<'a>, Error> {
355434
let (garbage, ciphertext) = match self.split_garbage(input_buffer) {
356435
Ok(split) => split,
@@ -360,22 +439,25 @@ impl<'a> Handshake<'a, SentVersion> {
360439
Err(e) => return Err(e),
361440
};
362441

363-
let mut aad = if garbage.is_empty() {
442+
let mut aad = if garbage.is_empty() || self.state.remote_garbage_authenticated {
364443
None
365444
} else {
366445
Some(garbage)
367446
};
368-
let mut ciphertext_index = 0;
369-
let mut found_version = false;
370447

371448
// First packet authenticates remote garbage.
372449
// Continue through decoys until we find version packet.
373-
while !found_version {
374-
match self.decrypt_packet(&ciphertext[ciphertext_index..], output_buffer, aad) {
450+
loop {
451+
match self.decrypt_packet(&mut ciphertext[self.state.ciphertext_index..], aad) {
375452
Ok((packet_type, bytes_consumed)) => {
376-
aad = None;
377-
ciphertext_index += bytes_consumed;
378-
found_version = matches!(packet_type, PacketType::Genuine);
453+
if aad.is_some() {
454+
aad = None;
455+
self.state.remote_garbage_authenticated = true;
456+
}
457+
self.state.ciphertext_index += bytes_consumed;
458+
if matches!(packet_type, PacketType::Genuine) {
459+
break;
460+
}
379461
}
380462
Err(Error::CiphertextTooSmall) => {
381463
return Ok(HandshakeAuthentication::NeedMoreData(self))
@@ -384,8 +466,9 @@ impl<'a> Handshake<'a, SentVersion> {
384466
}
385467
}
386468

387-
// Calculate total bytes consumed (garbage + terminator + packets)
388-
let bytes_consumed = garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + ciphertext_index;
469+
// Calculate total bytes consumed.
470+
let bytes_consumed =
471+
garbage.len() + NUM_GARBAGE_TERMINTOR_BYTES + self.state.ciphertext_index;
389472

390473
Ok(HandshakeAuthentication::Complete {
391474
cipher: self.state.cipher,
@@ -394,25 +477,27 @@ impl<'a> Handshake<'a, SentVersion> {
394477
}
395478

396479
/// Split buffer on garbage terminator.
397-
fn split_garbage<'b>(&self, buffer: &'b [u8]) -> Result<(&'b [u8], &'b [u8]), Error> {
480+
fn split_garbage<'b>(&self, buffer: &'b mut [u8]) -> Result<(&'b [u8], &'b mut [u8]), Error> {
398481
let terminator = &self.state.remote_garbage_terminator;
399482

400483
if let Some(index) = buffer
401484
.windows(terminator.len())
402485
.position(|window| window == terminator)
403486
{
404-
Ok((&buffer[..index], &buffer[index + terminator.len()..]))
487+
let (garbage, rest) = buffer.split_at_mut(index);
488+
let ciphertext = &mut rest[terminator.len()..];
489+
Ok((garbage, ciphertext))
405490
} else if buffer.len() >= MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES {
406491
Err(Error::NoGarbageTerminator)
407492
} else {
408493
Err(Error::CiphertextTooSmall)
409494
}
410495
}
411496

497+
/// Decrypts in place, returns the packet type and number of bytes consumed from the ciphertext.
412498
fn decrypt_packet(
413499
&mut self,
414-
ciphertext: &[u8],
415-
output_buffer: &mut [u8],
500+
ciphertext: &mut [u8],
416501
aad: Option<&[u8]>,
417502
) -> Result<(PacketType, usize), Error> {
418503
if ciphertext.len() < NUM_LENGTH_BYTES {
@@ -429,17 +514,8 @@ impl<'a> Handshake<'a, SentVersion> {
429514
return Err(Error::CiphertextTooSmall);
430515
}
431516

432-
// Check output buffer is large enough.
433-
let plaintext_len = InboundCipher::decryption_buffer_len(packet_len);
434-
if output_buffer.len() < plaintext_len {
435-
return Err(Error::BufferTooSmall {
436-
required_bytes: plaintext_len,
437-
});
438-
}
439-
440-
let packet_type = self.state.cipher.inbound().decrypt(
441-
&ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len],
442-
&mut output_buffer[..plaintext_len],
517+
let (packet_type, _) = self.state.cipher.inbound().decrypt_in_place(
518+
&mut ciphertext[NUM_LENGTH_BYTES..NUM_LENGTH_BYTES + packet_len],
443519
aad,
444520
)?;
445521

@@ -512,12 +588,10 @@ mod tests {
512588
.send_version(&mut resp_version_buffer, Some(&resp_decoys))
513589
.unwrap();
514590

515-
let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES];
516-
517591
// Initiator receives responder's garbage, decoys, and version.
518-
let full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat();
592+
let mut full_resp_message = [&responder_garbage[..], &resp_version_buffer[..]].concat();
519593
match init_handshake
520-
.receive_version(&full_resp_message, &mut packet_buffer)
594+
.receive_version(&mut full_resp_message)
521595
.unwrap()
522596
{
523597
HandshakeAuthentication::Complete { bytes_consumed, .. } => {
@@ -527,9 +601,9 @@ mod tests {
527601
}
528602

529603
// Responder receives initiator's garbage, decoys, and version.
530-
let full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat();
604+
let mut full_init_message = [&initiator_garbage[..], &init_version_buffer[..]].concat();
531605
match resp_handshake
532-
.receive_version(&full_init_message, &mut packet_buffer)
606+
.receive_version(&mut full_init_message)
533607
.unwrap()
534608
{
535609
HandshakeAuthentication::Complete { bytes_consumed, .. } => {
@@ -601,13 +675,10 @@ mod tests {
601675
.send_version(&mut resp_version_buffer, None)
602676
.unwrap();
603677

604-
let mut packet_buffer = vec![0u8; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES];
605-
606678
// Feed data in very small chunks to trigger NeedMoreData.
607679
let partial_data_1 = &init_version_buffer[..1];
608-
let returned_handshake = match resp_handshake
609-
.receive_version(partial_data_1, &mut packet_buffer)
610-
.unwrap()
680+
let mut partial_data_1 = partial_data_1.to_vec();
681+
let returned_handshake = match resp_handshake.receive_version(&mut partial_data_1).unwrap()
611682
{
612683
HandshakeAuthentication::NeedMoreData(handshake) => handshake,
613684
HandshakeAuthentication::Complete { .. } => {
@@ -617,8 +688,9 @@ mod tests {
617688

618689
// Feed a bit more data - still probably not enough.
619690
let partial_data_2 = &init_version_buffer[..5];
691+
let mut partial_data_2 = partial_data_2.to_vec();
620692
let returned_handshake = match returned_handshake
621-
.receive_version(partial_data_2, &mut packet_buffer)
693+
.receive_version(&mut partial_data_2)
622694
.unwrap()
623695
{
624696
HandshakeAuthentication::NeedMoreData(handshake) => handshake,
@@ -628,10 +700,8 @@ mod tests {
628700
};
629701

630702
// Now provide the complete data.
631-
match returned_handshake
632-
.receive_version(&init_version_buffer, &mut packet_buffer)
633-
.unwrap()
634-
{
703+
let mut full_data = init_version_buffer.clone();
704+
match returned_handshake.receive_version(&mut full_data).unwrap() {
635705
HandshakeAuthentication::Complete { bytes_consumed, .. } => {
636706
assert_eq!(bytes_consumed, init_version_buffer.len());
637707
}
@@ -657,13 +727,13 @@ mod tests {
657727
let handshake = handshake.send_version(&mut version_buffer, None).unwrap();
658728

659729
// Test with a buffer that is too long (should fail to find terminator)
660-
let test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES];
661-
let result = handshake.split_garbage(&test_buffer);
730+
let mut test_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES];
731+
let result = handshake.split_garbage(&mut test_buffer);
662732
assert!(matches!(result, Err(Error::NoGarbageTerminator)));
663733

664734
// Test with a buffer that's just short of the required length
665-
let short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1];
666-
let result = handshake.split_garbage(&short_buffer);
735+
let mut short_buffer = vec![0; MAX_NUM_GARBAGE_BYTES + NUM_GARBAGE_TERMINTOR_BYTES - 1];
736+
let result = handshake.split_garbage(&mut short_buffer);
667737
assert!(matches!(result, Err(Error::CiphertextTooSmall)));
668738
}
669739
}

0 commit comments

Comments
 (0)