55
66use core:: fmt;
77
8- #[ cfg( feature = "alloc " ) ]
9- use alloc :: vec;
10- #[ cfg( feature = "alloc " ) ]
11- use alloc :: vec:: Vec ;
8+ #[ cfg( feature = "std " ) ]
9+ use std :: vec;
10+ #[ cfg( feature = "std " ) ]
11+ use std :: vec:: Vec ;
1212
1313use bitcoin:: Network ;
1414
@@ -20,11 +20,38 @@ use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
2020use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt } ;
2121
2222use crate :: {
23- Error , Handshake , PacketReader , PacketType , PacketWriter , Payload , Role ,
24- NUM_ELLIGATOR_SWIFT_BYTES , NUM_GARBAGE_TERMINTOR_BYTES , NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ,
25- VERSION_CONTENT ,
23+ Error , Handshake , InboundCipher , OutboundCipher , PacketType , Role , NUM_ELLIGATOR_SWIFT_BYTES ,
24+ NUM_GARBAGE_TERMINTOR_BYTES , NUM_INITIAL_HANDSHAKE_BUFFER_BYTES , VERSION_CONTENT ,
2625} ;
2726
27+ /// A decrypted BIP324 payload with its packet type.
28+ #[ cfg( feature = "std" ) ]
29+ pub struct Payload {
30+ contents : Vec < u8 > ,
31+ packet_type : PacketType ,
32+ }
33+
34+ #[ cfg( feature = "std" ) ]
35+ impl Payload {
36+ /// Create a new payload.
37+ pub fn new ( contents : Vec < u8 > , packet_type : PacketType ) -> Self {
38+ Self {
39+ contents,
40+ packet_type,
41+ }
42+ }
43+
44+ /// Access the decrypted payload contents.
45+ pub fn contents ( & self ) -> & [ u8 ] {
46+ & self . contents
47+ }
48+
49+ /// Access the packet type.
50+ pub fn packet_type ( & self ) -> PacketType {
51+ self . packet_type
52+ }
53+ }
54+
2855/// High level error type for the protocol interface.
2956#[ cfg( feature = "std" ) ]
3057#[ derive( Debug ) ]
@@ -159,11 +186,11 @@ impl AsyncProtocol {
159186 let mut remote_ellswift_buffer = [ 0u8 ; 64 ] ;
160187 reader. read_exact ( & mut remote_ellswift_buffer) . await ?;
161188
162- let num_version_packet_bytes = PacketWriter :: required_packet_allocation ( & VERSION_CONTENT ) ;
189+ let num_version_packet_bytes = OutboundCipher :: encryption_buffer_len ( VERSION_CONTENT . len ( ) ) ;
163190 let num_decoy_packets_bytes: usize = match decoys {
164191 Some ( decoys) => decoys
165192 . iter ( )
166- . map ( |decoy| PacketWriter :: required_packet_allocation ( decoy) )
193+ . map ( |decoy| OutboundCipher :: encryption_buffer_len ( decoy. len ( ) ) )
167194 . sum ( ) ,
168195 None => 0 ,
169196 } ;
@@ -187,6 +214,8 @@ impl AsyncProtocol {
187214 // Keep pulling bytes from the buffer until the garbage is flushed.
188215 let mut remote_garbage_and_version_buffer =
189216 Vec :: with_capacity ( NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ) ;
217+ let mut packet_buffer = vec ! [ 0u8 ; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ] ;
218+
190219 loop {
191220 let mut temp_buffer = [ 0u8 ; NUM_INITIAL_HANDSHAKE_BUFFER_BYTES ] ;
192221 match reader. read ( & mut temp_buffer) . await {
@@ -197,12 +226,17 @@ impl AsyncProtocol {
197226 Ok ( bytes_read) => {
198227 remote_garbage_and_version_buffer. extend_from_slice ( & temp_buffer[ ..bytes_read] ) ;
199228
200- match handshake
201- . authenticate_garbage_and_version ( & remote_garbage_and_version_buffer)
202- {
229+ match handshake. authenticate_garbage_and_version (
230+ & remote_garbage_and_version_buffer,
231+ & mut packet_buffer,
232+ ) {
203233 Ok ( ( ) ) => break ,
204234 // Not enough data, continue reading.
205235 Err ( Error :: CiphertextTooSmall ) => continue ,
236+ Err ( Error :: BufferTooSmall { required_bytes } ) => {
237+ packet_buffer. resize ( required_bytes, 0 ) ;
238+ continue ;
239+ }
206240 Err ( e) => return Err ( ProtocolError :: Internal ( e) ) ,
207241 }
208242 }
@@ -216,15 +250,15 @@ impl AsyncProtocol {
216250 }
217251 }
218252
219- let packet_handler = handshake. finalize ( ) ?;
220- let ( packet_reader , packet_writer ) = packet_handler . into_split ( ) ;
253+ let cipher_session = handshake. finalize ( ) ?;
254+ let ( inbound_cipher , outbound_cipher ) = cipher_session . into_split ( ) ;
221255
222256 Ok ( Self {
223257 reader : AsyncProtocolReader {
224- packet_reader ,
258+ inbound_cipher ,
225259 state : DecryptState :: init_reading_length ( ) ,
226260 } ,
227- writer : AsyncProtocolWriter { packet_writer } ,
261+ writer : AsyncProtocolWriter { outbound_cipher } ,
228262 } )
229263 }
230264
@@ -280,7 +314,7 @@ impl DecryptState {
280314/// Manages an async buffer to automatically decrypt contents of received packets.
281315#[ cfg( any( feature = "futures" , feature = "tokio" ) ) ]
282316pub struct AsyncProtocolReader {
283- packet_reader : PacketReader ,
317+ inbound_cipher : InboundCipher ,
284318 state : DecryptState ,
285319}
286320
@@ -297,7 +331,7 @@ impl AsyncProtocolReader {
297331 /// # Returns
298332 ///
299333 /// A `Result` containing:
300- /// * `Ok(Payload)`: A decrypted payload.
334+ /// * `Ok(Payload)`: A decrypted payload with packet type .
301335 /// * `Err(ProtocolError)`: An error that occurred during the read or decryption.
302336 pub async fn read_and_decrypt < R > ( & mut self , buffer : & mut R ) -> Result < Payload , ProtocolError >
303337 where
@@ -314,7 +348,7 @@ impl AsyncProtocolReader {
314348 * bytes_read += buffer. read ( & mut length_bytes[ * bytes_read..] ) . await ?;
315349 }
316350
317- let packet_bytes_len = self . packet_reader . decypt_len ( * length_bytes) ;
351+ let packet_bytes_len = self . inbound_cipher . decrypt_packet_len ( * length_bytes) ;
318352 self . state = DecryptState :: init_reading_payload ( packet_bytes_len) ;
319353 }
320354 DecryptState :: ReadingPayload {
@@ -325,24 +359,28 @@ impl AsyncProtocolReader {
325359 * bytes_read += buffer. read ( & mut packet_bytes[ * bytes_read..] ) . await ?;
326360 }
327361
328- let payload = self . packet_reader . decrypt_payload ( packet_bytes, None ) ?;
362+ let plaintext_len = InboundCipher :: decryption_buffer_len ( packet_bytes. len ( ) ) ;
363+ let mut plaintext_buffer = vec ! [ 0u8 ; plaintext_len] ;
364+ let packet_type =
365+ self . inbound_cipher
366+ . decrypt ( packet_bytes, & mut plaintext_buffer, None ) ?;
329367 self . state = DecryptState :: init_reading_length ( ) ;
330- return Ok ( payload ) ;
368+ return Ok ( Payload :: new ( plaintext_buffer , packet_type ) ) ;
331369 }
332370 }
333371 }
334372 }
335373
336- /// Consume the protocol reader in exchange for the underlying packet decoder .
337- pub fn decoder ( self ) -> PacketReader {
338- self . packet_reader
374+ /// Consume the protocol reader in exchange for the underlying inbound cipher .
375+ pub fn into_cipher ( self ) -> InboundCipher {
376+ self . inbound_cipher
339377 }
340378}
341379
342380/// Manages an async buffer to automatically encrypt and send contents in packets.
343381#[ cfg( any( feature = "futures" , feature = "tokio" ) ) ]
344382pub struct AsyncProtocolWriter {
345- packet_writer : PacketWriter ,
383+ outbound_cipher : OutboundCipher ,
346384}
347385
348386#[ cfg( any( feature = "futures" , feature = "tokio" ) ) ]
@@ -366,16 +404,19 @@ impl AsyncProtocolWriter {
366404 where
367405 W : AsyncWrite + Unpin + Send ,
368406 {
369- let write_bytes =
370- self . packet_writer
371- . encrypt_packet ( plaintext, None , PacketType :: Genuine ) ?;
372- buffer. write_all ( & write_bytes[ ..] ) . await ?;
407+ let packet_len = OutboundCipher :: encryption_buffer_len ( plaintext. len ( ) ) ;
408+ let mut packet_buffer = vec ! [ 0u8 ; packet_len] ;
409+
410+ self . outbound_cipher
411+ . encrypt ( plaintext, & mut packet_buffer, PacketType :: Genuine , None ) ?;
412+
413+ buffer. write_all ( & packet_buffer) . await ?;
373414 buffer. flush ( ) . await ?;
374415 Ok ( ( ) )
375416 }
376417
377- /// Consume the protocol writer in exchange for the underlying packet encoder .
378- pub fn encoder ( self ) -> PacketWriter {
379- self . packet_writer
418+ /// Consume the protocol writer in exchange for the underlying outbound cipher .
419+ pub fn into_cipher ( self ) -> OutboundCipher {
420+ self . outbound_cipher
380421 }
381422}
0 commit comments