From b111e9dd6b00bed45010a6234945518b7a7fcb72 Mon Sep 17 00:00:00 2001 From: rustaceanrob Date: Tue, 29 Jul 2025 11:18:22 +0100 Subject: [PATCH 1/2] Gut `p2p` and use new `p2p` changes --- Cargo.toml | 2 +- p2p/Cargo.toml | 16 +- p2p/examples/feeler.rs | 7 +- p2p/examples/update_accumulator.rs | 112 ---------- p2p/src/lib.rs | 142 +++---------- p2p/src/net.rs | 1 + p2p/src/tokio_ext.rs | 314 ----------------------------- p2p/tests/tests.rs | 14 -- peers/Cargo.toml | 2 +- peers/tests/test.rs | 1 + utxo_verifier/Cargo.toml | 2 +- 11 files changed, 34 insertions(+), 579 deletions(-) delete mode 100644 p2p/examples/update_accumulator.rs delete mode 100644 p2p/src/tokio_ext.rs diff --git a/Cargo.toml b/Cargo.toml index e83c6df..acbe5ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ default-members = ["accumulator", "p2p", "peers", "utxo_verifier"] resolver = "3" [workspace.dependencies] -bitcoin = { git = "https://github.com/rust-bitcoin/rust-bitcoin", default-features = false, rev = "2bb9bb6bc99ba07ed3d543a512ec3d2a9462770d" } +bitcoin = { git = "https://github.com/rustaceanrob/rust-bitcoin", default-features = false, rev = "9f201781d9110be98757dc1c6782bc91c8a2c501" } diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index f978c4a..43a360d 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -10,31 +10,17 @@ rust-version = "1.75.0" [dependencies] bitcoin = { workspace = true, features = ["rand-std"] } -p2p = { package = "bitcoin-p2p-messages", git = "https://github.com/rust-bitcoin/rust-bitcoin", rev = "2bb9bb6bc99ba07ed3d543a512ec3d2a9462770d" } -tokio = { version = "1", default-features = false, optional = true, features = [ - "sync", - "io-util", - "time", - "net", -]} +p2p = { package = "bitcoin-p2p-messages", git = "https://github.com/rustaceanrob/rust-bitcoin", rev = "9f201781d9110be98757dc1c6782bc91c8a2c501" } [dev-dependencies] corepc-node = { version = "0.8.0", default-features = false, features = [ "29_0", "download" ] } -tokio = { version = "1", default-features = false, features = ["full"] } tracing = "0.1" tracing-subscriber = "0.3" accumulator = { path = "../accumulator/" } peers = { path = "../peers/", default-features = true } -[features] -default = ["tokio"] -tokio = ["dep:tokio"] - -[[example]] -name = "update_accumulator" - [[example]] name = "feeler" diff --git a/p2p/examples/feeler.rs b/p2p/examples/feeler.rs index 06c6ddd..62d9a6d 100644 --- a/p2p/examples/feeler.rs +++ b/p2p/examples/feeler.rs @@ -4,6 +4,7 @@ use bitcoin::{ secp256k1::rand::{seq::SliceRandom, thread_rng}, Network, }; +use p2p::message_network::UserAgent; use peers::{dns::DnsQuery, PortExt, SeedsExt}; use swiftsync_p2p::{net::ConnectionExt, ConnectionBuilder}; @@ -19,15 +20,15 @@ fn main() { let socket_addr = SocketAddr::new(any, NETWORK.port()); tracing::info!("Attempting a connection"); let connection = ConnectionBuilder::new() - .set_user_agent("/bitcoin-feeler:0.1.0".to_string()) + .set_user_agent(UserAgent::from_nonstandard("bitcoin-feeler")) .connection_timeout(Duration::from_millis(3500)) .change_network(NETWORK) .open_feeler(socket_addr); match connection { Ok(f) => { tracing::info!( - "Connection successful: Advertised protocol version {}, Adveristed services {}", - f.protocol_version.0, + "Connection successful: Advertised protocol version {:?}, Adveristed services {}", + f.protocol_version, f.services ); } diff --git a/p2p/examples/update_accumulator.rs b/p2p/examples/update_accumulator.rs deleted file mode 100644 index d564e6a..0000000 --- a/p2p/examples/update_accumulator.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - time::Instant, -}; - -use accumulator::Accumulator; -use bitcoin::{ - block::BlockUncheckedExt, - secp256k1::rand::{seq::SliceRandom, thread_rng}, - BlockHash, Network, OutPoint, -}; -use p2p::{ - message::NetworkMessage, - message_blockdata::{GetBlocksMessage, Inventory}, - ServiceFlags, -}; -use peers::{ - dns::{DnsQuery, TokioDnsExt}, - PortExt, -}; -use swiftsync_p2p::{ - tokio_ext::{TokioConnectionExt, TokioReadNetworkMessageExt, TokioWriteNetworkMessageExt}, - ConnectionBuilder, -}; - -const DNS_SEED: &str = "seed.bitcoin.sprovoost.nl"; -const NETWORK: Network = Network::Bitcoin; -const CLOUDFLARE: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53); -const START_HEIGHT: i32 = 900_000; - -#[tokio::main] -async fn main() { - let subscriber = tracing_subscriber::FmtSubscriber::new(); - tracing::subscriber::set_global_default(subscriber).unwrap(); - - let mut acc = Accumulator::new(); - let locator_hash = "000000000000000000010538edbfd2d5b809a33dd83f284aeea41c6d0d96968a" - .parse::() - .unwrap(); - let zero = BlockHash::from_byte_array([0u8; 32]); - tracing::info!("Configuring connection requirements"); - let connection_builder = ConnectionBuilder::new() - .change_network(NETWORK) - .add_start_height(START_HEIGHT) - .set_user_agent("/example-accumulator:0.1.0".to_string()) - .no_cmpct_blocks() - .announce_by_inv() - .their_services_expected(ServiceFlags::NETWORK); - tracing::info!("Querying DNS for peers"); - let dns = DnsQuery::new(DNS_SEED, CLOUDFLARE) - .lookup_async() - .await - .unwrap(); - tracing::info!("Connecting to the first result"); - let first = dns.choose(&mut thread_rng()).unwrap(); - let peer = SocketAddr::new(*first, NETWORK.port()); - let (mut stream, mut ctx) = connection_builder.open_connection(peer).await.unwrap(); - tracing::info!("Completed version handshake"); - let get_blocks_request = GetBlocksMessage::new(vec![locator_hash], zero); - let message = NetworkMessage::GetBlocks(get_blocks_request); - tracing::info!("Requesting blocks"); - stream.write_message(message, &mut ctx).await.unwrap(); - tracing::info!("Waiting for response"); - loop { - let response = stream.read_message(&mut ctx).await.unwrap(); - if let Some(message) = response { - match message { - NetworkMessage::Ping(nonce) => stream - .write_message(NetworkMessage::Pong(nonce), &mut ctx) - .await - .unwrap(), - NetworkMessage::Inv(data) => { - if data - .0 - .iter() - .any(|inv| matches!(inv, Inventory::Block(_) | Inventory::WitnessBlock(_))) - { - let getdata = NetworkMessage::GetData(data); - stream.write_message(getdata, &mut ctx).await.unwrap(); - } - } - NetworkMessage::Block(block) => { - let checked = block.validate().unwrap(); - let hash = checked.block_hash(); - tracing::info!("Validated block: {hash}"); - let now = Instant::now(); - tracing::info!("Updating the accumulator"); - for tx in checked.transactions() { - for input in &tx.input { - let outpoint = input.previous_output; - acc.spend(outpoint); - } - let txid = tx.compute_txid(); - for ind in 0..tx.output.len() { - let outpoint = OutPoint { - txid, - vout: ind as u32, - }; - acc.add(outpoint); - } - } - tracing::info!( - "Updated accumulator in {} milliseconds", - now.elapsed().as_millis() - ); - return; - } - other => tracing::info!("{}", other.cmd()), - } - } - } -} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 2eaf25b..1a1ca90 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -5,17 +5,14 @@ use std::{ use bitcoin::{consensus, FeeRate, Network}; use p2p::{ - message::{CommandString, NetworkMessage, RawNetworkMessage}, - message_network::VersionMessage, - Address, Magic, ServiceFlags, + message::{NetworkMessage, RawNetworkMessage}, + message_network::{UserAgent, VersionMessage}, + Address, Magic, ProtocolVersion, ServiceFlags, }; use validation::ValidationExt; /// Extension traits for `std` networking tools. pub mod net; -/// Extension traits for use with the `tokio` asynchronous runtime framework. -#[cfg(feature = "tokio")] -pub mod tokio_ext; mod validation; @@ -26,33 +23,6 @@ pub const DEFAULT_USER_AGENT: &str = "/swiftsync:0.1.0/"; const LOCAL_HOST: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); const UNREACHABLE: SocketAddr = SocketAddr::V4(SocketAddrV4::new(LOCAL_HOST, 0)); -/// A version of the Bitcoin peer-to-peer messages. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, std::hash::Hash)] -pub struct ProtocolVerison(pub u32); - -impl ProtocolVerison { - /// Support for relaying transactions by WTXID - pub const WTXID_RELAY: ProtocolVerison = ProtocolVerison(70016); - /// Invalid compact blocks are not a ban - pub const NO_BAN_CMPCT: ProtocolVerison = ProtocolVerison(70015); - /// Compact block message support - pub const CMPCT_BLOCKS: ProtocolVerison = ProtocolVerison(70014); - /// Support the `feefilter` message - pub const FEE_FILTER: ProtocolVerison = ProtocolVerison(70013); - /// Support `sendheaders` message to advertise new blocks with `header` messages - pub const SEND_HEADERS: ProtocolVerison = ProtocolVerison(70012); - /// Support NODE_BLOOM messages and do not support bloom filter messages if not set - pub const NODE_BLOOM: ProtocolVerison = ProtocolVerison(70011); - /// Support `reject` messages - pub const REJECT: ProtocolVerison = ProtocolVerison(70002); - /// Support bloom filter messages - pub const BLOOM_FILTERS: ProtocolVerison = ProtocolVerison(70001); - /// Support `mempool` messages - pub const MEMPOOL: ProtocolVerison = ProtocolVerison(60002); - /// Support `ping` and `pong` messages - pub const PING_PONG: ProtocolVerison = ProtocolVerison(60001); -} - /// The context for the connection. This includes data like the current cipher state, their /// services offered, their fee filter, last message time, and more. #[derive(Debug)] @@ -67,7 +37,7 @@ impl ConnectionContext { read_half: ReadHalf, negotiation: Negotiation, their_services: ServiceFlags, - their_version: ProtocolVerison, + their_version: ProtocolVersion, ) -> Self { let read_ctx = ReadContext { read_half, @@ -173,7 +143,6 @@ impl ReadContext { fn is_valid(&self, message: &NetworkMessage) -> bool { match &message { - NetworkMessage::FeeFilter(f) => *f > 0, NetworkMessage::Headers(h) => h.is_valid(), NetworkMessage::GetData(r) => r.0.is_valid(), NetworkMessage::Inv(r) => r.0.is_valid(), @@ -184,10 +153,6 @@ impl ReadContext { fn update_metadata(&mut self, message: &NetworkMessage) { self.last_message = Instant::now(); match &message { - NetworkMessage::FeeFilter(f) => { - let fee_rate = FeeRate::from_sat_per_kwu(*f as u32 / 4); - self.fee_filter = fee_rate; - } NetworkMessage::Alert(_) => self.final_alert = true, NetworkMessage::SendHeaders => { self.negotiation.send_headers.them = true; @@ -209,7 +174,7 @@ pub struct WriteContext { write_half: WriteHalf, negotiation: Negotiation, their_services: ServiceFlags, - their_protocol_verison: ProtocolVerison, + their_protocol_verison: ProtocolVersion, } impl WriteContext { @@ -263,11 +228,11 @@ pub struct ConnectionBuilder { our_ip: SocketAddr, offered_services: ServiceFlags, their_services: ServiceFlags, - our_version: ProtocolVerison, - their_version: ProtocolVerison, + our_version: ProtocolVersion, + their_version: ProtocolVersion, offer: Offered, start_height: i32, - user_agent: String, + user_agent: UserAgent, tcp_timeout: Duration, } @@ -279,11 +244,11 @@ impl ConnectionBuilder { our_ip: UNREACHABLE, offered_services: ServiceFlags::NONE, their_services: ServiceFlags::NONE, - our_version: ProtocolVerison::WTXID_RELAY, - their_version: ProtocolVerison::WTXID_RELAY, + our_version: ProtocolVersion::WTXID_RELAY_VERSION, + their_version: ProtocolVersion::WTXID_RELAY_VERSION, offer: Offered::default(), start_height: 0, - user_agent: DEFAULT_USER_AGENT.to_string(), + user_agent: UserAgent::from_nonstandard(DEFAULT_USER_AGENT), tcp_timeout: Duration::from_secs(2), } } @@ -305,7 +270,7 @@ impl ConnectionBuilder { } /// Downgrade your advertised version. - pub fn downgrade_to_version(self, us: ProtocolVerison) -> Self { + pub fn downgrade_to_version(self, us: ProtocolVersion) -> Self { Self { our_version: us, ..self @@ -313,7 +278,7 @@ impl ConnectionBuilder { } /// Accept a minimum version. - pub fn accept_minimum_version(self, them: ProtocolVerison) -> Self { + pub fn accept_minimum_version(self, them: ProtocolVersion) -> Self { Self { their_version: them, ..self @@ -337,7 +302,7 @@ impl ConnectionBuilder { } /// Set the user agent sent as part of your version message. - pub fn set_user_agent(self, user_agent: String) -> Self { + pub fn set_user_agent(self, user_agent: UserAgent) -> Self { Self { user_agent, ..self } } @@ -435,40 +400,16 @@ impl Default for Offered { #[derive(Debug, Clone, Copy)] pub struct Feeler { pub services: ServiceFlags, - pub protocol_version: ProtocolVerison, -} - -pub(crate) struct MessageHeader { - magic: Magic, - _command: CommandString, - length: u32, - _checksum: u32, -} - -impl consensus::Decodable for MessageHeader { - fn consensus_decode( - reader: &mut R, - ) -> Result { - let magic = Magic::consensus_decode(reader)?; - let _command = CommandString::consensus_decode(reader)?; - let length = u32::consensus_decode(reader)?; - let _checksum = u32::consensus_decode(reader)?; - Ok(Self { - magic, - _command, - length, - _checksum, - }) - } + pub protocol_version: ProtocolVersion, } fn make_version( - version: ProtocolVerison, + version: ProtocolVersion, our_services: ServiceFlags, their_services: ServiceFlags, our_ip: SocketAddr, start_height: i32, - user_agent: String, + user_agent: UserAgent, nonce: u64, ) -> VersionMessage { let now = SystemTime::now() @@ -478,7 +419,7 @@ fn make_version( let them = Address::new(&UNREACHABLE, their_services); let us = Address::new(&our_ip, our_services); VersionMessage { - version: version.0, + version, services: our_services, timestamp: now, receiver: them, @@ -494,22 +435,14 @@ fn make_version( fn interpret_first_message( message: NetworkMessage, nonce: u64, - their_expected_version: ProtocolVerison, + their_expected_version: ProtocolVersion, their_expected_services: ServiceFlags, -) -> Result<(ProtocolVerison, ServiceFlags), HandshakeError> { +) -> Result<(ProtocolVersion, ServiceFlags), HandshakeError> { if let NetworkMessage::Version(version) = message { if version.nonce.eq(&nonce) { return Err(HandshakeError::ConnectedToSelf); } - if version.version < their_expected_version.0 { - return Err(HandshakeError::TooLowVersion(ProtocolVerison( - version.version, - ))); - } - if !version.services.has(their_expected_services) { - return Err(HandshakeError::UnsupportedFeature); - } - Ok((ProtocolVerison(version.version), version.services)) + Ok((version.version, version.services)) } else { Err(HandshakeError::IrrelevantMessage(message)) } @@ -564,7 +497,7 @@ impl std::error::Error for ParseMessageError {} #[derive(Debug, Clone)] pub enum HandshakeError { /// Their version is too low for the configured preferences. - TooLowVersion(ProtocolVerison), + TooLowVersion(ProtocolVersion), /// Some message was sent before the handshake completed. IrrelevantMessage(NetworkMessage), /// This is a connection to self. @@ -591,7 +524,7 @@ impl std::fmt::Display for HandshakeError { "a feature we require is not supported by the connection." ), Self::TooLowVersion(version) => { - write!(f, "the remote peer had a too-low version: {}", version.0) + write!(f, "the remote peer had a too-low version: {:?}", version) } } } @@ -610,7 +543,7 @@ macro_rules! define_read_message_logic { let mut message_buf = vec![0_u8; 24]; read!(&mut message_buf)?; - let header: $crate::MessageHeader = consensus::deserialize_partial(&message_buf) + let header: V1MessageHeader = consensus::deserialize_partial(&message_buf) .map_err(ParseMessageError::Consensus)? .0; if header.magic != $magic { @@ -741,26 +674,12 @@ macro_rules! define_version_message_logic { }}; } -#[cfg(feature = "tokio")] -macro_rules! async_awaiter { - ($e:expr) => { - $e.await - }; -} - macro_rules! blocking_awaiter { ($e:expr) => { $e }; } -#[cfg(feature = "tokio")] -macro_rules! read_message_async { - ($reader:expr, $magic:expr) => { - $crate::define_read_message_logic!(async_awaiter, $reader, $magic) - }; -} - macro_rules! read_message_blocking { ($reader:expr, $magic:expr) => { $crate::define_read_message_logic!(blocking_awaiter, $reader, $magic) @@ -773,19 +692,6 @@ macro_rules! version_handshake_blocking { }; } -#[cfg(feature = "tokio")] -macro_rules! version_handshake_async { - ($reader:expr, $conn:ident) => { - $crate::define_version_message_logic!(async_awaiter, $reader, $conn) - }; -} - -#[cfg(feature = "tokio")] -pub(crate) use async_awaiter; pub(crate) use blocking_awaiter; -#[cfg(feature = "tokio")] -pub(crate) use read_message_async; pub(crate) use read_message_blocking; -#[cfg(feature = "tokio")] -pub(crate) use version_handshake_async; pub(crate) use version_handshake_blocking; diff --git a/p2p/src/net.rs b/p2p/src/net.rs index d844d85..dc3ac50 100644 --- a/p2p/src/net.rs +++ b/p2p/src/net.rs @@ -8,6 +8,7 @@ use bitcoin::secp256k1::rand; use p2p::message::NetworkMessage; use p2p::message::RawNetworkMessage; use p2p::message_compact_blocks::SendCmpct; +use p2p::message::V1MessageHeader; use p2p::Magic; use crate::{ diff --git a/p2p/src/tokio_ext.rs b/p2p/src/tokio_ext.rs deleted file mode 100644 index e4e38fb..0000000 --- a/p2p/src/tokio_ext.rs +++ /dev/null @@ -1,314 +0,0 @@ -use ::std::fmt::{Debug, Display}; -use std::net::SocketAddr; - -use bitcoin::consensus; -use bitcoin::secp256k1::rand; -use p2p::message::{NetworkMessage, RawNetworkMessage}; -use p2p::message_compact_blocks::SendCmpct; -use p2p::Magic; -use tokio::io::AsyncWriteExt; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::{ - io::{self, AsyncReadExt}, - net::TcpStream, -}; - -use crate::{ - async_awaiter, interpret_first_message, make_version, version_handshake_async, - ConnectionBuilder, ConnectionContext, Feeler, HandshakeError, Negotiation, ParseMessageError, - ReadContext, ReadHalf, WriteContext, WriteHalf, -}; - -/// Connect to peers using `tokio`. -pub trait TokioConnectionExt { - type Error: Debug + Display + Send + Sync + std::error::Error; - - /// Open a TCP connection to a peer. - #[allow(async_fn_in_trait)] - async fn open_connection( - self, - to: impl Into, - ) -> Result<(TcpStream, ConnectionContext), Self::Error>; - - /// Start a handshake with a pre-existing connection. Normally used after establishing a Socks5 - /// proxy connection. - #[allow(async_fn_in_trait)] - async fn start_handshake( - self, - tcp_stream: TcpStream, - ) -> Result<(TcpStream, ConnectionContext), Self::Error>; - - /// Open a feeler to test a node's liveliness - #[allow(async_fn_in_trait)] - async fn open_feeler(self, to: impl Into) -> Result; -} - -impl TokioConnectionExt for ConnectionBuilder { - type Error = ConnectionError; - - async fn open_connection( - mut self, - to: impl Into, - ) -> Result<(TcpStream, ConnectionContext), Self::Error> { - let socket_addr = to.into(); - let timeout = tokio::time::timeout(self.tcp_timeout, TcpStream::connect(socket_addr)).await; - let mut tcp_stream = - timeout.map_err(|_| ConnectionError::Protocol(HandshakeError::Timeout))??; - version_handshake_async!(tcp_stream, self) - } - - async fn start_handshake( - mut self, - mut tcp_stream: TcpStream, - ) -> Result<(TcpStream, ConnectionContext), Self::Error> { - version_handshake_async!(tcp_stream, self) - } - - async fn open_feeler(mut self, to: impl Into) -> Result { - let socket_addr = to.into(); - let timeout = tokio::time::timeout(self.tcp_timeout, TcpStream::connect(socket_addr)).await; - let mut tcp_stream = - timeout.map_err(|_| ConnectionError::Protocol(HandshakeError::Timeout))??; - let res: Result<(TcpStream, ConnectionContext), Self::Error> = - version_handshake_async!(tcp_stream, self); - let (_, ctx) = res?; - let services = ctx.write_ctx.their_services; - let protocol_version = ctx.write_ctx.their_protocol_verison; - Ok(Feeler { - services, - protocol_version, - }) - } -} - -async fn write_message( - write: &mut W, - message: NetworkMessage, - write_half: &mut WriteHalf, -) -> Result<(), io::Error> { - let msg_bytes = write_half.serialize_message(message); - write.write_all(&msg_bytes).await?; - write.flush().await?; - Ok(()) -} - -trait TokioTransportExt { - #[allow(async_fn_in_trait)] - async fn read_message( - &mut self, - reader: &mut R, - ) -> Result, ReadError>; -} - -impl TokioTransportExt for ReadHalf { - async fn read_message( - &mut self, - reader: &mut R, - ) -> Result, ReadError> { - match self { - Self::V1(magic) => { - crate::read_message_async!(reader, *magic) - } - } - } -} - -/// Write bitcoin network messages directly over `tokio` TCP streams. -pub trait TokioWriteNetworkMessageExt { - /// Write a message with the current context. - #[allow(async_fn_in_trait)] - async fn write_message( - &mut self, - message: NetworkMessage, - ctx: impl AsMut, - ) -> Result<(), WriteError>; -} - -impl TokioWriteNetworkMessageExt for TcpStream { - fn write_message( - &mut self, - message: NetworkMessage, - ctx: impl AsMut, - ) -> impl std::future::Future> { - write_for_any(self, message, ctx) - } -} - -impl TokioWriteNetworkMessageExt for OwnedWriteHalf { - fn write_message( - &mut self, - message: NetworkMessage, - ctx: impl AsMut, - ) -> impl std::future::Future> { - write_for_any(self, message, ctx) - } -} - -async fn write_for_any( - writer: &mut W, - message: NetworkMessage, - mut ctx: impl AsMut, -) -> Result<(), WriteError> { - let ctx = ctx.as_mut(); - if !ctx.ok_to_send(&message) { - return Err(WriteError::NotRecommended(message)); - }; - write_message(writer, message, &mut ctx.write_half).await?; - Ok(()) -} - -/// Read a message directly off a TCP stream. -pub trait TokioReadNetworkMessageExt { - /// Try to read a message and error otherwise. - /// - /// This method performs some light validation to ensure the node is not sending spam or - /// non-sensical messages. - #[allow(async_fn_in_trait)] - async fn read_message( - &mut self, - ctx: impl AsMut, - ) -> Result, ReadError>; -} - -impl TokioReadNetworkMessageExt for TcpStream { - async fn read_message( - &mut self, - mut rtx: impl AsMut, - ) -> Result, ReadError> { - let ctx = rtx.as_mut(); - let message = ctx.read_half.read_message(self).await?; - match message { - Some(message) => { - if !ctx.ok_to_recv_message(&message) { - return Err(ReadError::NonsenseMessage(message)); - } - if !ctx.is_valid(&message) { - return Err(ReadError::ParseMessageError(ParseMessageError::Malformed)); - } - ctx.update_metadata(&message); - Ok(Some(message)) - } - None => Ok(None), - } - } -} - -impl TokioReadNetworkMessageExt for OwnedReadHalf { - async fn read_message( - &mut self, - mut rtx: impl AsMut, - ) -> Result, ReadError> { - let ctx = rtx.as_mut(); - let message = ctx.read_half.read_message(self).await?; - match message { - Some(message) => { - if !ctx.ok_to_recv_message(&message) { - return Err(ReadError::NonsenseMessage(message)); - } - if !ctx.is_valid(&message) { - return Err(ReadError::ParseMessageError(ParseMessageError::Malformed)); - } - ctx.update_metadata(&message); - Ok(Some(message)) - } - None => Ok(None), - } - } -} - -// Error implementation section - -/// Errors that may occur when starting a connection. -#[derive(Debug)] -pub enum ConnectionError { - /// Read or write failure. - Io(io::Error), - /// The handshake failed to malformed messages or a mismatch in preferences. - Protocol(HandshakeError), - /// A message that was read violated the protocol. - Reader(ReadError), -} - -impl From for ConnectionError { - fn from(value: io::Error) -> Self { - Self::Io(value) - } -} - -impl From for ConnectionError { - fn from(value: ReadError) -> Self { - Self::Reader(value) - } -} - -impl Display for ConnectionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Io(io) => write!(f, "{io}"), - Self::Protocol(proto) => write!(f, "{proto}"), - Self::Reader(read) => write!(f, "{read}"), - } - } -} - -impl std::error::Error for ConnectionError {} - -/// Errors when attempting to write a message. -#[derive(Debug)] -pub enum WriteError { - /// Writing to the stream failed. - Io(io::Error), - /// The message is invalid or not supported. - NotRecommended(NetworkMessage), -} - -impl From for WriteError { - fn from(value: io::Error) -> Self { - WriteError::Io(value) - } -} - -impl Display for WriteError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Io(io) => write!(f, "{io}"), - Self::NotRecommended(msg) => write!(f, "non-sensical message: {}", msg.cmd()), - } - } -} - -/// Errors when reading messages off of the stream. -#[derive(Debug)] -pub enum ReadError { - /// The message violates the protocol. Normally, these are deprecated messages or messages that - /// should have been sent during the handshake. - NonsenseMessage(NetworkMessage), - /// Parsing a message failed. - ParseMessageError(ParseMessageError), - /// The stream was closed. - Io(io::Error), -} - -impl Display for ReadError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::ParseMessageError(r) => write!(f, "{r}"), - Self::Io(io) => write!(f, "{io}"), - Self::NonsenseMessage(n) => write!(f, "{}", n.cmd()), - } - } -} - -impl std::error::Error for ReadError {} - -impl From for ReadError { - fn from(value: ParseMessageError) -> Self { - Self::ParseMessageError(value) - } -} - -impl From for ReadError { - fn from(value: io::Error) -> Self { - Self::Io(value) - } -} diff --git a/p2p/tests/tests.rs b/p2p/tests/tests.rs index 32fb64f..0a19c98 100644 --- a/p2p/tests/tests.rs +++ b/p2p/tests/tests.rs @@ -91,17 +91,3 @@ fn enforces_desired_services() { assert!(ok.is_ok()); bitcoind.stop().unwrap(); } - -// Tokio tests - -#[tokio::test] -async fn does_handshake_async() { - use swiftsync_p2p::tokio_ext::TokioConnectionExt; - let (mut bitcoind, socket_addr) = TestNodeBuilder::new().start(); - let _ = ConnectionBuilder::new() - .change_network(Network::Regtest) - .open_connection(socket_addr) - .await - .unwrap(); - bitcoind.stop().unwrap(); -} diff --git a/peers/Cargo.toml b/peers/Cargo.toml index bf791a0..a192941 100644 --- a/peers/Cargo.toml +++ b/peers/Cargo.toml @@ -14,5 +14,5 @@ tokio = { version = "1", default-features = false, optional = true, features = [ tokio = { version = "1", default-features = false, features = ["full"] } [features] -default = ["tokio"] +default = [] tokio = ["dep:tokio"] diff --git a/peers/tests/test.rs b/peers/tests/test.rs index 98b152e..887e6dd 100644 --- a/peers/tests/test.rs +++ b/peers/tests/test.rs @@ -3,6 +3,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use peers::dns::DnsQuery; #[tokio::test] +#[cfg(feature = "tokio")] async fn test_tokio_dns_ext() { use peers::dns::TokioDnsExt; let resolver = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(46, 166, 189, 67)), 53); diff --git a/utxo_verifier/Cargo.toml b/utxo_verifier/Cargo.toml index 421d2f5..ce438a9 100644 --- a/utxo_verifier/Cargo.toml +++ b/utxo_verifier/Cargo.toml @@ -8,7 +8,7 @@ accumulator = { path = "../accumulator/" } bitcoin = { workspace = true, features = ["rand-std"] } peers = { path = "../peers/", default-features = false } p2p = { path = "../p2p/", package = "swiftsync-p2p", default-features = false } -p2p-messages = { package = "bitcoin-p2p-messages", git = "https://github.com/rust-bitcoin/rust-bitcoin", rev = "2bb9bb6bc99ba07ed3d543a512ec3d2a9462770d" } +p2p-messages = { package = "bitcoin-p2p-messages", git = "https://github.com/rustaceanrob/rust-bitcoin", rev = "9f201781d9110be98757dc1c6782bc91c8a2c501" } tracing = "0.1" tracing-subscriber = "0.3" rusqlite = { version = "0.36.0", features = ["bundled"] } From ad364b19f7aac4ff0f682de8de470871ed00e73a Mon Sep 17 00:00:00 2001 From: rustaceanrob Date: Tue, 29 Jul 2025 15:05:46 +0100 Subject: [PATCH 2/2] WIP handshake --- p2p/src/handshake.rs | 424 +++++++++++++++++++++++++++++++++++++++++++ p2p/src/lib.rs | 1 + 2 files changed, 425 insertions(+) create mode 100644 p2p/src/handshake.rs diff --git a/p2p/src/handshake.rs b/p2p/src/handshake.rs new file mode 100644 index 0000000..ba57591 --- /dev/null +++ b/p2p/src/handshake.rs @@ -0,0 +1,424 @@ +use std::{ + cmp::min, + fmt::Display, + net::{Ipv4Addr, SocketAddrV4}, + sync::atomic::{self, AtomicBool}, + time::{SystemTime, UNIX_EPOCH}, +}; + +use bitcoin::Network; +use p2p::{ + message::{CommandString, NetworkMessage}, + message_compact_blocks::SendCmpct, + message_network::{Alert, ClientSoftwareVersion, UserAgent, UserAgentVersion, VersionMessage}, + ProtocolVersion, ServiceFlags, +}; + +const CLIENT_VERSION: ClientSoftwareVersion = ClientSoftwareVersion::SemVer { + major: 0, + minor: 1, + revision: 0, +}; +const AGENT_VERSION: UserAgentVersion = UserAgentVersion::new(CLIENT_VERSION); +const CLIENT_NAME: &str = "SwiftSync"; +const LOCAL_HOST_IP: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); +const LOCAL_HOST: SocketAddrV4 = SocketAddrV4::new(LOCAL_HOST_IP, 0); +const NETWORK: Network = Network::Bitcoin; + +#[derive(Debug, Clone)] +struct ConnectionConfig { + our_version: ProtocolVersion, + our_services: ServiceFlags, + expected_version: ProtocolVersion, + expected_services: ServiceFlags, + send_cmpct: SendCmpct, + user_agent: UserAgent, + network: Network, +} + +impl ConnectionConfig { + fn new() -> Self { + let user_agent = UserAgent::new(CLIENT_NAME, AGENT_VERSION); + Self { + our_version: ProtocolVersion::WTXID_RELAY_VERSION, + our_services: ServiceFlags::NONE, + expected_version: ProtocolVersion::WTXID_RELAY_VERSION, + expected_services: ServiceFlags::NONE, + send_cmpct: SendCmpct { + send_compact: false, + version: 0, + }, + user_agent, + network: NETWORK, + } + } + + fn change_network(mut self, network: Network) -> Self { + self.network = network; + self + } + + fn decrease_version_requirement(mut self, protocol_version: ProtocolVersion) -> Self { + self.expected_version = protocol_version; + self + } + + fn set_service_requirement(mut self, service_flags: ServiceFlags) -> Self { + self.expected_services = service_flags; + self + } + + fn offer_services(mut self, service_flags: ServiceFlags) -> Self { + self.our_services = service_flags; + self + } + + fn user_agent(mut self, user_agent: UserAgent) -> Self { + self.user_agent = user_agent; + self + } + + fn send_cmpct(mut self, send_cmpct: SendCmpct) -> Self { + self.send_cmpct = send_cmpct; + self + } + + fn start_handshake( + self, + network_message: NetworkMessage, + origin: Origin, + ) -> Result { + // The first message must always be a `version` + let version = match network_message { + NetworkMessage::Version(version) => version, + e => return Err(VersionError::IrrelevantMessage(e.command())), + }; + let time_received = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time went backwards"); + let mut suggested_messages = Vec::new(); + // Add the version message to the stack if we are receiving a version + if let Origin::Inbound { nonce, version } = origin { + if version.nonce.eq(&nonce) { + return Err(VersionError::ConnectionToSelf); + } + suggested_messages.push(NetworkMessage::Version(version)); + } + // Reject an incompatible peer based on our requirements + if version.version < self.expected_version + || version.version < ProtocolVersion::MIN_PEER_PROTO_VERSION + { + return Err(VersionError::TooLowVersion(version.version)); + } + if !version.services.has(self.expected_services) { + return Err(VersionError::NotEnoughServices(version.services)); + } + // Prepare messages to send based on the effective version + let effective_version = min(self.our_version, version.version); + if effective_version > ProtocolVersion::WTXID_RELAY_VERSION { + suggested_messages.push(NetworkMessage::WtxidRelay); + } + // Weird case where this number is not a constant in Bitcoin Core + if effective_version > ProtocolVersion::from_nonstandard(70016) { + suggested_messages.push(NetworkMessage::SendAddrV2); + } + if effective_version > ProtocolVersion::SENDHEADERS_VERSION { + suggested_messages.push(NetworkMessage::SendHeaders); + } else { + suggested_messages.push(NetworkMessage::Alert(Alert::final_alert())); + } + let net_time_diff = time_received.as_secs_f64() as i64 - version.timestamp; + let metadata = ConnectionMetadata { + lowest_common_version: effective_version, + their_services: version.services, + net_time_difference: net_time_diff, + reported_height: version.start_height, + prefers_addrv2: AtomicBool::new(false), + prefers_wtxid: AtomicBool::new(false), + prefers_headers: AtomicBool::new(false), + prefers_cmpct: AtomicBool::new(false), + }; + let initial_handshake = InitializedHandshake { + metadata, + suggested_messages, + }; + Ok(initial_handshake) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum Origin { + Inbound { nonce: u64, version: VersionMessage }, + OutBound, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct ConnectionMetadata { + lowest_common_version: ProtocolVersion, + their_services: ServiceFlags, + net_time_difference: i64, + reported_height: i32, + prefers_headers: AtomicBool, + prefers_wtxid: AtomicBool, + prefers_addrv2: AtomicBool, + prefers_cmpct: AtomicBool, +} + +// States + +#[derive(Debug)] +struct InitializedHandshake { + metadata: ConnectionMetadata, + suggested_messages: Vec, +} + +impl InitializedHandshake { + fn take_suggested(&mut self) -> impl Iterator { + core::mem::take(&mut self.suggested_messages).into_iter() + } + + fn negotiate( + self, + network_message: NetworkMessage, + ) -> Result { + match &network_message { + NetworkMessage::Verack => Ok(NegotiationUpdate::Finished(CompletedHandshake { + metadata: self.metadata, + suggested_messages: Vec::new(), + })), + NetworkMessage::Alert(_) => { + if self.metadata.lowest_common_version > ProtocolVersion::INVALID_CB_NO_BAN_VERSION + { + return Err(NegotiationError::UnexpectedMessage( + network_message.command(), + )); + } + Ok(NegotiationUpdate::Updated(self)) + } + NetworkMessage::WtxidRelay => { + if self.metadata.lowest_common_version < ProtocolVersion::WTXID_RELAY_VERSION { + return Err(NegotiationError::UnexpectedMessage( + network_message.command(), + )); + } + self.metadata + .prefers_wtxid + .store(true, atomic::Ordering::Relaxed); + Ok(NegotiationUpdate::Updated(self)) + } + NetworkMessage::SendAddrV2 => { + if self.metadata.lowest_common_version < ProtocolVersion::from_nonstandard(70016) { + return Err(NegotiationError::UnexpectedMessage( + network_message.command(), + )); + } + self.metadata + .prefers_addrv2 + .store(true, atomic::Ordering::Relaxed); + Ok(NegotiationUpdate::Updated(self)) + } + NetworkMessage::SendHeaders => { + if self.metadata.lowest_common_version < ProtocolVersion::SENDHEADERS_VERSION { + return Err(NegotiationError::UnexpectedMessage( + network_message.command(), + )); + } + self.metadata + .prefers_headers + .store(true, atomic::Ordering::Relaxed); + Ok(NegotiationUpdate::Updated(self)) + } + NetworkMessage::SendCmpct(_) => { + if self.metadata.lowest_common_version < ProtocolVersion::SHORT_IDS_BLOCKS_VERSION { + return Err(NegotiationError::UnexpectedMessage( + network_message.command(), + )); + } + self.metadata + .prefers_cmpct + .store(true, atomic::Ordering::Relaxed); + Ok(NegotiationUpdate::Updated(self)) + } + e => Err(NegotiationError::UnexpectedMessage(e.command())), + } + } +} + +#[derive(Debug)] +enum NegotiationUpdate { + Finished(CompletedHandshake), + Updated(InitializedHandshake), +} + +#[derive(Debug)] +struct CompletedHandshake { + metadata: ConnectionMetadata, + suggested_messages: Vec, +} + +impl CompletedHandshake { + fn take_suggested(&mut self) -> impl Iterator { + core::mem::take(&mut self.suggested_messages).into_iter() + } +} + +// Errors + +#[derive(Debug, Clone)] +enum VersionError { + IrrelevantMessage(CommandString), + ConnectionToSelf, + TooLowVersion(ProtocolVersion), + NotEnoughServices(ServiceFlags), +} + +impl Display for VersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectionToSelf => write!(f, "connected to self."), + Self::TooLowVersion(_) => write!(f, "remote version too low."), + Self::NotEnoughServices(s) => write!(f, "not enough services: {s}"), + Self::IrrelevantMessage(cmd) => write!(f, "unexpected message: {cmd}"), + } + } +} + +impl std::error::Error for VersionError {} + +#[derive(Debug, Clone)] +enum NegotiationError { + UnexpectedMessage(CommandString), +} + +impl Display for NegotiationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UnexpectedMessage(c) => write!(f, "unexpected negotiation message {c}"), + } + } +} + +impl std::error::Error for NegotiationError {} + +#[cfg(test)] +mod tests { + use p2p::{ + message::NetworkMessage, + message_network::{UserAgent, VersionMessage}, + ProtocolVersion, ServiceFlags, + }; + + use super::{ConnectionConfig, NegotiationUpdate, Origin, AGENT_VERSION, CLIENT_NAME}; + + #[test] + fn test_complete_successful() { + let connection = ConnectionConfig::new(); + + let version = VersionMessage { + version: ProtocolVersion::WTXID_RELAY_VERSION, + services: ServiceFlags::NONE, + timestamp: 1232132131, + sender: p2p::Address { + services: ServiceFlags::NONE, + address: [0; 8], + port: 0, + }, + receiver: p2p::Address { + services: ServiceFlags::NONE, + address: [0; 8], + port: 0, + }, + start_height: 0, + nonce: 67, + user_agent: UserAgent::new(CLIENT_NAME, AGENT_VERSION), + relay: false, + }; + let version_message = NetworkMessage::Version(version); + let mut initial_handshake = connection + .start_handshake(version_message, Origin::OutBound) + .unwrap(); + let outbound_messages = initial_handshake.take_suggested(); + for message in outbound_messages { + assert!(!matches!(message, NetworkMessage::Alert(_))); + } + let update = initial_handshake + .negotiate(NetworkMessage::WtxidRelay) + .unwrap(); + match update { + NegotiationUpdate::Updated(handshake) => initial_handshake = handshake, + NegotiationUpdate::Finished(_) => panic!("handshake incomplete"), + } + let update = initial_handshake + .negotiate(NetworkMessage::SendAddrV2) + .unwrap(); + match update { + NegotiationUpdate::Updated(handshake) => initial_handshake = handshake, + NegotiationUpdate::Finished(_) => panic!("handshake incomplete"), + } + let update = initial_handshake + .negotiate(NetworkMessage::SendHeaders) + .unwrap(); + match update { + NegotiationUpdate::Updated(handshake) => initial_handshake = handshake, + NegotiationUpdate::Finished(_) => panic!("handshake incomplete"), + } + let update = initial_handshake.negotiate(NetworkMessage::Verack).unwrap(); + match update { + NegotiationUpdate::Updated(_) => panic!("handshake is over"), + NegotiationUpdate::Finished(c) => { + assert!(c + .metadata + .prefers_wtxid + .load(std::sync::atomic::Ordering::Relaxed)); + assert!(c + .metadata + .prefers_headers + .load(std::sync::atomic::Ordering::Relaxed)); + assert!(c + .metadata + .prefers_addrv2 + .load(std::sync::atomic::Ordering::Relaxed)); + } + } + } + + #[test] + fn test_finds_connection_to_self() { + let connection = ConnectionConfig::new(); + + let our_version = VersionMessage { + version: ProtocolVersion::WTXID_RELAY_VERSION, + services: ServiceFlags::NONE, + timestamp: 1232132131, + sender: p2p::Address { + services: ServiceFlags::NONE, + address: [0; 8], + port: 0, + }, + receiver: p2p::Address { + services: ServiceFlags::NONE, + address: [0; 8], + port: 0, + }, + start_height: 0, + nonce: 67, + user_agent: UserAgent::new(CLIENT_NAME, AGENT_VERSION), + relay: false, + }; + let version_message = NetworkMessage::Version(our_version.clone()); + let initial_handshake = connection.start_handshake( + version_message, + Origin::Inbound { + nonce: 67, + version: our_version, + }, + ); + assert!(initial_handshake.is_err()); + } +} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 1a1ca90..7700509 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -15,6 +15,7 @@ use validation::ValidationExt; pub mod net; mod validation; +mod handshake; /// The maximum network message size in bytes. pub const MAX_MESSAGE_SIZE: u32 = 1024 * 1024 * 32;