diff --git a/core/Cargo.toml b/core/Cargo.toml index ab8e52b..7b9f42f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -30,14 +30,14 @@ pqc_kyber = { version = "0.7.1", features = ["std", "kyber512"] } rand_chacha = "0.3.1" pqc_dilithium_edit = "0.2.0" async-trait = "0.1.89" -tokio = { version = "1.48.0", features = ["sync", "rt", "macros", "rt-multi-thread", "time", "fs", "io-util", "signal"] } +tokio = { version = "1.48.0", features = ["sync", "rt", "macros", "rt-multi-thread", "time", "fs", "io-util", "signal", "net"] } bytes = { version = "1.11.0", features = ["serde"] } serde = { version = "1.0.228", features = ["derive"] } -postcard = { version = "1.0", features = ["alloc"] } +postcard = { version = "1.0", features = ["alloc", "use-std"] } heapless = "0.9.2" tokio-macros = "2.6.0" aes-gcm = { version= "0.10.3", features = ["rand_core"] } -tokio-util = "0.7.17" +tokio-util = { version = "0.7.17", features = ["join-map", "rt"] } thiserror = "2.0.17" serde-big-array = "0.5.1" tokio-stream = "0.1.18" diff --git a/core/src/message.rs b/core/src/message.rs index ec0b676..9b3e651 100644 --- a/core/src/message.rs +++ b/core/src/message.rs @@ -5,7 +5,7 @@ use crate::{payload::{Payload, Query, Reply}, peer::PeerId}; pub type MsgId = u64; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct IncomingMessage { pub from: PeerId, pub payload: Payload, @@ -21,7 +21,7 @@ impl IncomingMessage { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct OutgoingMessage { pub to: PeerId, pub payload: Payload, diff --git a/core/src/net/mod.rs b/core/src/net/mod.rs index 80bf267..a00ffaa 100644 --- a/core/src/net/mod.rs +++ b/core/src/net/mod.rs @@ -5,10 +5,12 @@ mod packet; mod types; mod client; mod session_store; +mod service; pub use _session::*; pub use error::*; pub use packet::*; pub use types::*; pub use client::*; -pub use session::*; \ No newline at end of file +pub use session::*; +pub use service::*; \ No newline at end of file diff --git a/core/src/net/service.rs b/core/src/net/service.rs new file mode 100644 index 0000000..ccb2354 --- /dev/null +++ b/core/src/net/service.rs @@ -0,0 +1,67 @@ +use tokio_util::sync::CancellationToken; +use tokio::task::JoinError; +use thiserror::Error; +use crate::{net::{session::ActiveSession,error::{NetError,CryptoError}},message::{IncomingMessage,OutgoingMessage},payload::Payload,peer::{Peer,PeerId},utils::{SerdeError,ChannelError}, transport::{TransportError}}; + +#[derive(Debug,Error)] +pub enum NetServError { + #[error(transparent)] + Serde(#[from] SerdeError), + + #[error(transparent)] + Channel(#[from] ChannelError), + + #[error(transparent)] + Crypto(#[from] CryptoError), + + #[error(transparent)] + Net(#[from] NetError), + + #[error(transparent)] + Join(#[from] JoinError), + + #[error(transparent)] + Transport(#[from] TransportError), + + #[error("client not found")] + ClientNotFound, + + #[error("sessions not found")] + SessionsNotFound, + + #[error("Session already exists")] + SessionAlreadyExists, + + #[error("peer not found")] + PeerNotFound, + +} + + + +pub trait NetService{ + // Interface for handling connections out from a single peer/client + + type Error: Into; + + // Add fully formed sessions to the service + // Deal with handshakes prior to adding into service + // Fail on adding a session if the peer already exists + fn add_session(&mut self, client: (Peer,ActiveSession)) -> impl Future> + Send; + // It is possible that you'd want multiple sessions between two peers. Currently this would be an error. + // If not changed now, it will be hard to fix in the future. + + // Close a session and drop it from the table + fn drop_session(&mut self, peer: &PeerId) -> impl Future> + Send; + + // Listen for incoming messages from all peers + fn listen(&mut self, token: CancellationToken) -> impl Future> + Send; + + // Broadcast messages to all sessions + // Responsible for encrypting for each peer + fn broadcast(&mut self, msg: Payload, token: CancellationToken) -> impl Future> + Send; + + // Transmit messages to a specific session + // OutgoingMessage has its own PeerID + fn transmit(&mut self, msg: Payload,target: PeerId, token: CancellationToken) -> impl Future> + Send; +} \ No newline at end of file diff --git a/core/src/net/session/session.rs b/core/src/net/session/session.rs index f5082c1..659e068 100644 --- a/core/src/net/session/session.rs +++ b/core/src/net/session/session.rs @@ -78,7 +78,7 @@ impl ActiveSession { aad } - pub fn send(&mut self, plaintext: &Vec) -> Result { + pub fn send(&mut self, plaintext: &[u8]) -> Result { self.seq += 1; let cipher = self.cipher(); diff --git a/core/src/payload/dht.rs b/core/src/payload/dht.rs index 384e9f8..140c65c 100644 --- a/core/src/payload/dht.rs +++ b/core/src/payload/dht.rs @@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize}; use crate::{dht::CID, payload::{Query, QueryError, Reply, ReplyError, TryFromQuery, TryFromReply}}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum DhtQuery { Get(CID), Put(Bytes), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum DhtReply { Return(Option), } diff --git a/core/src/payload/mod.rs b/core/src/payload/mod.rs index 72dc564..7de9fb4 100644 --- a/core/src/payload/mod.rs +++ b/core/src/payload/mod.rs @@ -9,20 +9,20 @@ pub use dht::*; use serde::{Deserialize, Serialize}; use thiserror::Error; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum Payload { Query(Query), Reply(Reply), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum Query { Pow(PowQuery), Tag(TagQuery), Dht(DhtQuery), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum Reply { Empty, Ok, @@ -46,7 +46,7 @@ pub enum ReplyError { pub trait TryFromQuery: TryFrom {} pub trait TryFromReply: TryFrom {} -#[derive(Serialize, Deserialize, Clone, Copy, Debug)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)] #[repr(u16)] pub enum Action { PublishTag = 1, diff --git a/core/src/payload/pow.rs b/core/src/payload/pow.rs index fc063f3..6f4d812 100644 --- a/core/src/payload/pow.rs +++ b/core/src/payload/pow.rs @@ -3,18 +3,18 @@ use thiserror::Error; use crate::{payload::{Action, Query, QueryError, Reply, ReplyError, TryFromQuery, TryFromReply}, pow::Pow}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum PowQuery { Get(Action), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug,PartialEq,Clone)] pub enum PowReply { Require(Pow), Err(PowReplyErr), } -#[derive(Serialize, Deserialize, Debug, Error)] +#[derive(Serialize, Deserialize, Debug, Error, PartialEq,Clone)] pub enum PowReplyErr { #[error("Incorrect nonce")] IncorrectNonce, diff --git a/core/src/payload/tag.rs b/core/src/payload/tag.rs index b01f7ad..49e0673 100644 --- a/core/src/payload/tag.rs +++ b/core/src/payload/tag.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::{payload::{Query, QueryError, Reply, ReplyError, TryFromQuery, TryFromReply}, pow::Pow, tag::Tag}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum TagQuery { Get, Publish { @@ -12,7 +12,7 @@ pub enum TagQuery { }, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub enum TagReply { Return(Vec), } diff --git a/core/src/pow.rs b/core/src/pow.rs index f8b1a20..6ebe616 100644 --- a/core/src/pow.rs +++ b/core/src/pow.rs @@ -13,7 +13,7 @@ fn pow_input(secret: &[u8; 32], timestamp: u64, action: Action, random: &[u8; 16 hasher.finalize().into() } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq,Clone)] pub struct Pow { pub input: [u8; 32], pub timestamp: u64, diff --git a/core/src/tag/tag.rs b/core/src/tag/tag.rs index 601547c..38fb582 100644 --- a/core/src/tag/tag.rs +++ b/core/src/tag/tag.rs @@ -8,7 +8,7 @@ use bytes::{BufMut, Bytes, BytesMut}; use crate::{VERSION, tag::TagError, utils::{self, deserialize, random_bytes, serialize}}; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Tag { nonce: [u8; 12], content: Vec, diff --git a/core/src/transport/error.rs b/core/src/transport/error.rs index 1f19858..94fbbbd 100644 --- a/core/src/transport/error.rs +++ b/core/src/transport/error.rs @@ -1,6 +1,7 @@ use thiserror::Error; +use tokio::{task::JoinError,sync::mpsc::error::{SendError},net::tcp::OwnedReadHalf}; -use crate::{net::SessionManagerDispatcherError, utils::ChannelError}; +use crate::{net::{SessionManagerDispatcherError,CryptoError,NetError,Message}, utils::ChannelError,peer::{PeerId}}; #[derive(Debug, Error)] pub enum MockTransportError { @@ -15,4 +16,54 @@ pub enum MockTransportError { #[error("peer not found")] PeerNotFound, +} + +#[derive(Debug, Error)] +pub enum TransportError { + // Make PeerId optional so that callers can add information without having to pass PeerId down to every function + #[error("session's read task has been cancelled")] + ReadCancelled(OwnedReadHalf,Option), + + #[error("request was cancelled by token")] + Cancelled, + + // Consider returning peer ID as data in this error + #[error("session not found")] + SessionNotFound(Option), + + // Consider returning peer ID as data in this error + #[error("connection not found in active map")] + ConnectionNotInMap(Option), + + // Consider returning peer ID as data in this error + #[error("no sessions available")] + NoSessions, + + // Consider returning peer ID as data in this error + #[error("message channel has been closed")] + MessageChannelClosed, + + #[error("peer already connected")] + PeerAlreadyConnected(Option), + + #[error("connection closed")] + ConnectionClosed(Option), + + #[error(transparent)] + Serialization(#[from] postcard::Error), + + #[error(transparent)] + Reading(#[from] SendError::<(PeerId,Message)>), + + #[error(transparent)] + IO(#[from] std::io::Error), + + #[error(transparent)] + Join(#[from] JoinError), + + #[error(transparent)] + Encrypt(#[from] CryptoError), + + #[error(transparent)] + Decrypt(#[from] NetError), } \ No newline at end of file diff --git a/core/src/transport/mod.rs b/core/src/transport/mod.rs index 1ae1f7a..7e1bd39 100644 --- a/core/src/transport/mod.rs +++ b/core/src/transport/mod.rs @@ -2,6 +2,7 @@ mod mock; mod error; mod controls; mod participant; +mod tcp; pub use mock::*; pub use error::*; diff --git a/core/src/transport/tcp.rs b/core/src/transport/tcp.rs new file mode 100644 index 0000000..f9d0b09 --- /dev/null +++ b/core/src/transport/tcp.rs @@ -0,0 +1,639 @@ +use tokio_util::{sync::CancellationToken,task::JoinMap}; + +use tokio::{net::{TcpStream, TcpListener, ToSocketAddrs, tcp::{OwnedReadHalf, OwnedWriteHalf}}, sync::{mpsc::{channel,Sender, Receiver}}, io::{AsyncReadExt}}; + +use std::{collections::HashMap}; + +use bytes::BytesMut; + +use postcard::{to_stdvec,take_from_bytes,from_bytes}; + +use crate::{net::{ActiveSession, PendingSession, NetClient,NetService,Message},payload::{Payload}, transport::TransportError, peer::{Peer,PeerId}, message::{IncomingMessage,OutgoingMessage}}; + +const CHANNELSIZE: usize = 128; +// Max message len current 226 +const BUFSIZE: usize = 256; +pub struct TcpTransport{ + incoming_messages: Receiver<(PeerId,Message)>, + message_sender: Sender<(PeerId, Message)>, + active_conns: JoinMap>, + cancel_tokens: HashMap, + sessions: HashMap, + write_streams: HashMap, + client: NetClient, + listener: TcpListener, // Used to establish a new session - not in trait currently - todo +} + +impl TcpTransport { + pub async fn bind(client: NetClient, addr: T) -> Result { + // Create a listener on given IP + let listener = TcpListener::bind(addr).await?; + let (sender,receiver) = channel(CHANNELSIZE); + Ok(TcpTransport { + incoming_messages: receiver, + message_sender: sender, + active_conns: JoinMap::new(), + cancel_tokens: HashMap::new(), + sessions: HashMap::new(), + write_streams: HashMap::new(), + client, + listener, + }) + } +} + +async fn read_from_stream(stream: &mut OwnedReadHalf) -> Result { + let res = loop { + stream.readable().await?; + let mut buf = BytesMut::with_capacity(BUFSIZE); + let res = stream.read_buf(&mut buf).await; + match res { + Ok(0) => { + return Err(TransportError::ConnectionClosed(None)); + }, + Ok(_) => { + break Ok(buf); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e.into()), + } + }; + res +} + +async fn read_message(stream: &mut OwnedReadHalf,id: PeerId, results: &Sender::<(PeerId,Message)>) -> Result<(), TransportError> { + let mut msg_buf = BytesMut::with_capacity(BUFSIZE); + loop { + let data = read_from_stream(stream).await.map_err(|e| { + match e { + // Add peer Id info + TransportError::ConnectionClosed(None) => { + TransportError::ConnectionClosed(Some(id)) + }, + _ => e, + } + })?; + msg_buf.extend_from_slice(&data); + let rem = match take_from_bytes::(&msg_buf) { + Ok((msg,rem)) => { + results.send((id,msg)).await?; + BytesMut::from(rem) + }, + Err(e) => { + match e { + postcard::Error::DeserializeUnexpectedEnd => { + // Buffer not large enough - continue reading + continue + }, + _ => return Err(e.into()), + } + } + }; + // take_from_bytes cannot mutate data so msg_buf will still contain the used + remaining bytes + // replace it with the remaining bytes + msg_buf = rem; + } +} + +impl NetService for TcpTransport { + type Error = TransportError; + + async fn add_session(&mut self, client: (Peer, ActiveSession)) -> Result<(), Self::Error> { + if self.sessions.contains_key(&client.0.id) { + return Err(TransportError::PeerAlreadyConnected(Some(client.0.id))); + } + let (mut reader,writer) = TcpStream::connect(client.0.address).await?.into_split(); + self.sessions.insert(client.0.id, client.1); + self.write_streams.insert(client.0.id,writer); + + let cncl = CancellationToken::new(); + let cncl_task = cncl.clone(); + let send = self.message_sender.clone(); + self.active_conns.spawn(client.0.id,async move { + // Use cancellation token instead of handle abortion to retrieve the OwnedReadHalf on cancel + loop { + tokio::select!{ + _ = cncl_task.cancelled() => return Err(TransportError::ReadCancelled(reader,Some(client.0.id))), + out = read_message(&mut reader,client.0.id,&send) => { + match out { + Ok(_) => {}, + Err(e) => { + return Err(e); + } + } + }, + } + } + }); + self.cancel_tokens.insert(client.0.id, cncl); + Ok(()) + } + + async fn drop_session(&mut self, peer: &PeerId) -> Result<(), Self::Error> { + self.sessions.remove(peer).ok_or(TransportError::SessionNotFound(Some(*peer)))?; + self.write_streams.remove(peer).ok_or(TransportError::SessionNotFound(Some(*peer)))?; + let aborted = self.active_conns.abort(peer); + if !aborted { + return Err(TransportError::ConnectionNotInMap(Some(*peer))) + }; + self.cancel_tokens.remove(peer); + Ok(()) + } + + async fn listen(&mut self, token: CancellationToken) -> Result { + loop { + tokio::select!{ + _ = token.cancelled() => { return Err(TransportError::Cancelled); }, + Some((id,res)) = self.active_conns.join_next() => { + match res { + Ok(r) => { + match r { + Ok(_) => {}, // Session ended gracefully, ignore it + Err(e) => { + match self.drop_session(&id).await { + Ok(_) => {}, + Err(e) => { + match e { + // join_next removes the connection from the map + // drop_session would return an error that we can ignore + // We still want to know if other parts of drop_session fail + TransportError::ConnectionNotInMap(_) => {}, + _ => return Err(e) + } + } + } + return Err(e); + } + } + }, + // We don't want to stop listening if the connection had previously been cancelled and is hanging around + Err(e) if e.is_cancelled() => continue, + Err(e) => return Err(e.into()) + } + } + msg = self.incoming_messages.recv() => { + if let Some((id,encrypted)) = msg { + let session = self.sessions.get_mut(&id).ok_or(TransportError::SessionNotFound(Some(id)))?; + let decrypted = session.receive(encrypted)?; + let outgoing = from_bytes::(&decrypted)?; + return Ok(IncomingMessage::receive(id, outgoing)) + } else { + // Channel has been closed + return Err(TransportError::MessageChannelClosed) + } + } + } + } + } + + async fn broadcast(&mut self, msg: Payload, token: CancellationToken) -> Result<(), Self::Error> { + let keys = self.sessions.keys().map(|k| k.clone()).collect::>(); + // Todo - make this non-serial if poss? Transmit requires mutable borrow of self due to session needing to be mutable. + for peer in keys { + self.transmit(msg.clone(), peer, token.clone()).await?; + } + Ok(()) + } + + async fn transmit(&mut self, msg: Payload, target: PeerId, token: CancellationToken) -> Result<(), Self::Error> { + let session = self.sessions.get_mut(&target).ok_or(TransportError::SessionNotFound(Some(target)))?; + let stream = self.write_streams.get(&target).ok_or(TransportError::SessionNotFound(Some(target)))?; + + let res = tokio::select!{ + _ = token.cancelled() => { + Err(TransportError::Cancelled) + }, + _ = async { + // Serialize once to get a format that can be encrypted + let data = to_stdvec(&msg)?; + let encrypted_data = session.send(&data)?; + // Serialize again to get a sendable stream + let serialized_data = to_stdvec(&encrypted_data)?; + loop { + stream.writable().await?; + match stream.try_write(&serialized_data) { + Ok(_) => break, + Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e.into()), + }; + } + Ok::<(),TransportError>(()) + } => Ok(()), + }; + + match res { + Ok(()) => Ok(()), + Err(err) => Err(err), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{message::OutgoingMessage,payload::{Action, Query, Reply, TagQuery}, pow::Pow, tag::{Tag,TagPayload}}; + use std::{io::Write, sync::mpsc::{Sender, channel}}; + use tokio::time::timeout; + use std::{time::{Duration}}; + use std::io::Read; + use postcard::from_bytes; + + // Timeout for async function calls + const TIMEOUT: Duration = Duration::from_millis(10); + // struct size is 200 + const BUFSIZE: usize = 256; + + // Create a generic transport for testing + async fn ephemeral_transport() -> TcpTransport { + TcpTransport::bind(NetClient::Ephemeral, "127.0.0.1:0").await.unwrap() + } + + async fn static_transport() -> TcpTransport { + TcpTransport::bind(NetClient::from_seed([1u8; 32]), "127.0.0.1:0").await.unwrap() + } + + #[tokio::test] + async fn test_bind() { + // This test needs to use concrete addresses (non port 0) to check it fails correctly + let client = NetClient::Ephemeral; + let addr = "127.0.0.1:30000"; + let transport = TcpTransport::bind(client, addr).await; + + let static_client = NetClient::from_seed([1u8; 32]); + let addr = "127.0.0.1:30030"; + let transport_2 = TcpTransport::bind(static_client, addr).await; + + let client = NetClient::from_seed([1u8; 32]); + let transport_3 = TcpTransport::bind(client, addr).await; + + + assert!(transport.is_ok(), "Bind should succeed with Ephemeral client but got: {}",transport.err().unwrap()); + assert!(transport_2.is_ok(), "Bind should succeed with Static client but got: {}",transport_2.err().unwrap()); + assert!(transport_3.is_err(), "Bind should fail with duplicate address but got success"); + match transport_3.err().unwrap() { + TransportError::IO(_) => {}, + err @ _ => panic!("Expected PeerAlreadyConnected error, got: {}", err) + } + } + + #[tokio::test] + async fn test_add_session() { + let mut transport = ephemeral_transport().await; + let client = NetClient::from_seed([1u8;32]); + + + // Start a listener and move it into a thread after getting the port assigned by the OS. + let srv = TcpListener::bind("127.0.0.1:0").await.expect("Failed to start test server"); + let addr = srv.local_addr().expect("Failed to get local address of thread"); + tokio::spawn(async move { + srv.accept().await + }); + + // First add should succeed - unique Peer ID + let peer = Peer::new(client.identity().expect("Expect static identity"), addr.to_string()); + let pend_session = PendingSession::new([0u8;32],Some(0u64)); + let session = pend_session.activate(None).expect("Failed to active test session"); + let expect_ok = transport.add_session((peer, session)).await; + assert!(expect_ok.is_ok(),"Failed to add session: {}",expect_ok.err().unwrap()); + assert_eq!(transport.sessions.len(), 1); + + // Creating a new Peer with the same peer ID should fail when added. + // See comment in NetService trait about whether this behaviour should be changed. + let peer = Peer::new(client.identity().expect("Expect static identity"), addr.to_string()); + let pend_session = PendingSession::new([1u8;32],Some(1u64)); + let session = pend_session.activate(None).expect("Failed to active test session"); + let expect_fail = transport.add_session((peer,session)).await; + assert!(expect_fail.is_err(),"Expected failure adding second entry with same peer ID"); + } + + #[tokio::test] + async fn test_drop_session() { + let srv = TcpListener::bind("127.0.0.1:0").await.expect("Failed to start test server"); + let addr = srv.local_addr().expect("Failed to get local address of thread"); + + // Create a transport client with 3 sessions + let mut transport = ephemeral_transport().await; + let mut ids = vec!(); + for i in 0..3 { + let client = NetClient::from_seed([i as u8;32]); + let peer = Peer::new(client.identity().expect("Expect static identity"), addr.to_string()); + + let pend_session = PendingSession::new([i as u8;32],Some(i as u64)); + let net_session = pend_session.activate(None).expect("Failed to active test session"); + transport.sessions.insert(peer.id,net_session); + + let (_,write_stream) = TcpStream::connect(addr).await.expect("Failed to connect to test server on {addr}").into_split(); + transport.write_streams.insert(peer.id,write_stream); + transport.active_conns.spawn(peer.id, async {Ok(())}); + transport.cancel_tokens.insert(peer.id, CancellationToken::new()); + + + ids.push(peer.id); + } + + assert_eq!(transport.sessions.len(),3,"Expected 3 elements in starting transport"); + assert_eq!(transport.write_streams.len(),3,"Expected 3 elements in starting transport"); + assert_eq!(transport.active_conns.len(),3,"Expected 3 elements in starting transport"); + + for i in (0..3).rev() { + let expect_ok = transport.drop_session(&ids.pop().unwrap()).await; + assert!(expect_ok.is_ok()); + assert_eq!(transport.sessions.len(),i,"Failed to remove session {i} from sessions"); + assert_eq!(transport.write_streams.len(),i,"Failed to remove session {i} from write_streams"); + } + + let client = NetClient::from_seed([5 as u8;32]); + let peer = Peer::new(client.identity().expect("Expect static identity"), addr.to_string()); + let expect_err = transport.drop_session(&peer.id).await; + + assert!(expect_err.is_err(),"Expected failure removing non-existent session"); + + } + + #[tokio::test] + async fn test_transmit(){ + // Start a listener and move it into a thread after getting the port assigned by the OS. + // Send a channel into the thread to read success/failure of expected values + // Give it a handler function checking expected results and sending true/false over a channel + + // Create a dummy active session that can be used for testing message encryption + let (mut sender_shared_session,mut receiver_shared_session ) = gen_shared_sessions([0u8;32],5u64); + + // Test that shared keys match and encryption/decryption work + let dummy_data = [1u8; 32]; + let encrypt = sender_shared_session.send(&dummy_data).expect("Failed to encrypt dummy data"); + let decrypt = receiver_shared_session.receive(encrypt).expect("Failed to decrypt dummy data"); + assert_eq!(decrypt, dummy_data,"Decrypted dummy data did not match - discontinuing"); + + let client = NetClient::from_seed([1u8;32]); + let (send,results) = channel::(); + let expect_messages = sample_messages(); + let send_messages = expect_messages.clone(); + + let srv = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to start test server"); + let addr = srv.local_addr().expect("Failed to get local address of thread"); + + std::thread::spawn(move || { + + let (stream, _) = srv.accept().expect("Failed to accept incoming stream"); + expect_message_tcp( + stream, + expect_messages, + &mut receiver_shared_session, + send.clone()); + }); + let mut transport = ephemeral_transport().await; + let peer = Peer::new(client.identity().expect("Expect static identity"), addr.to_string()); + transport.add_session((peer.clone(),sender_shared_session)).await.expect("Failed to add test session to transport"); + for (i,msg) in send_messages.into_iter().enumerate(){ + let cncl = CancellationToken::new(); + let res = timeout( + TIMEOUT, + transport.transmit(msg,peer.id.clone(), cncl) + ).await; + assert!(res.is_ok(),"Failed to transmit message {}: {}",i,res.err().unwrap()); + let got_msg = results.recv().expect("Failed to check receive result"); + assert!(got_msg,"Message {} did not match expected",i); + } + + } + + #[tokio::test] + async fn test_listen(){ + + let mut transport = static_transport().await; + + let mut threads = vec!(); + + let mut expect_messages = vec!(); + // Set up test clients to listen to - send a different message from each one + let sample_messages = sample_messages(); + let num_messages = sample_messages.len(); + for (i,msg) in sample_messages.into_iter().enumerate() { + let (mut sender_session,receiver_session) = gen_shared_sessions([i as u8;32],i as u64 + 12); + + let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to start test server"); + let addr = listener.local_addr().expect("Failed to get local address of thread"); + + let client = NetClient::from_seed([i as u8;32]); + let peer = Peer::new(client.identity().expect("Expect static ID"), addr.to_string()); + + expect_messages.push( + IncomingMessage{ + from: peer.id.clone(), + payload: msg.clone(), + id: 0, + } + ); + + let send_message = OutgoingMessage{ + to: transport.client.identity().expect("Expected an ID").peer_id(), + payload: msg.clone(), + id: 0, + }; + + transport.add_session((peer,receiver_session)).await.expect("Failed to add test session to transport"); + + let handle = std::thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("Failed to accept incoming stream"); + let msg_bytes = to_stdvec(&send_message).expect("Failed to serialize message"); + let encrypt = sender_session.send(&msg_bytes).expect("Failed to encrypt message"); + let serialized = to_stdvec(&encrypt).expect("Failed to serialize encrypted message"); + stream.write_all(&serialized).expect("Failed to write to stream"); + stream + }); + threads.push(handle); + + + } + + let mut streams = vec!(); + for thread in threads { + streams.push(thread.join().expect("Failed to join thread")); + } + + let mut got_messages = vec!(); + for _ in 0..num_messages { + let cncl = CancellationToken::new(); + let msg = transport.listen(cncl).await.expect("Failed to listen for messages"); + got_messages.push(msg); + } + + for msg in expect_messages { + assert!(got_messages.contains(&msg), "Received message did not match expected" ); + } + + drop(streams); + + } + + #[tokio::test] + async fn test_broadcast(){ + let num_test_clients: usize = 5; + + let mut transport = ephemeral_transport().await; + let test_message = Payload::Query(Query::Tag(TagQuery::Get)); + let (res_sender,res_receiver) = std::sync::mpsc::channel::(); + + let mut threads = vec!(); + // Set up test clients to broadcast to + for i in 1..=num_test_clients { + let (sender_session,mut receiver_session) = gen_shared_sessions([i as u8;32],i as u64 + 12); + + let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to start test server"); + let addr = listener.local_addr().expect("Failed to get local address of thread"); + + let client = NetClient::from_seed([i as u8;32]); + let peer = Peer::new(client.identity().expect("Expect static ID"), addr.to_string()); + + let res_chan = res_sender.clone(); + + transport.add_session((peer,sender_session)).await.expect("Failed to add test session to transport"); + let handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().expect("Failed to accept incoming stream"); + let expect_messages = vec!(Payload::Query(Query::Tag(TagQuery::Get))); + expect_message_tcp( + stream, + expect_messages, + &mut receiver_session, + res_chan); + }); + threads.push(handle); + } + + // Broadcast the test message - extend the timeout to account for the number of messages + let cncl = CancellationToken::new(); + let res = timeout( + TIMEOUT * num_test_clients as u32, + transport.broadcast(test_message, cncl) + ).await; + + for thread in threads { + thread.join().expect("A thread failed"); + } + + // Close the channel to stop the iterator below hanging + drop(res_sender); + + // Check results + assert!(res.is_ok(),"Failed to broadcast message {}",res.err().unwrap()); + let got_msg = res_receiver.iter().collect::>(); + assert_eq!(got_msg.len(),num_test_clients,"Expected {num_test_clients} but received {} results",got_msg.len()); + assert!(got_msg.iter().all(|x| *x),"Not all messages received"); + } + + // Helper to test receiving messages over TCP + fn expect_message_tcp(mut stream: std::net::TcpStream, expect_messages: Vec, session: &mut ActiveSession, results: Sender) { + for expect_message in expect_messages { + 'receive: loop { + let mut buf = [0u8;BUFSIZE]; + match stream.read(&mut buf) { + Ok(0) => break, + Ok(_) => {}, + Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => { + continue; + }, + Err(e) => { + eprintln!("Error reading message: {}", e); + results.send(false).expect("Failed to send test failure due to TCP read"); + return + } + } + + match from_bytes::(&buf) { + Err(e) => { + match e { + postcard::Error::SerdeDeCustom => { + // Golden path - deserialization failure because the message is encrypted. + }, + _ => { + // No real reason to end up here. + eprintln!("Error receiving encrypted message: {}", e); + results.send(false).expect("Failed to send test failure"); + } + } + } + Ok(_) => { + // Failure - message was not encrypted + results.send(false).expect("Failed to send test failure due to bad encrypton"); + + } + } + + let encrypted = match from_bytes::(&buf) { + Ok(msg) => { + msg + }, + Err(e) => { + match e { + postcard::Error::DeserializeUnexpectedEnd => { + // Loop again to get more data + continue 'receive; + }, + _ => { + eprintln!("Error reading encrypted message: {}", e); + results.send(false).expect("Failed to send test failure"); + break; + } + } + } + }; + + let decrypted = match session.receive(encrypted) { + Ok(msg) => { + msg + }, + Err(e) => { + eprintln!("Failed to decrypt message: {}", e); + results.send(false).expect("Failed to send test failure due to decryption"); + break; + } + }; + + match from_bytes::(&decrypted) { + Ok(msg) => { + if msg == expect_message { + // Golden path + results.send(true).expect("Failed to send test success"); + break; + } else { + // Bad deserialization + eprintln!("Bad message. Got: \n{msg:?}\n expected: \n{expect_message:?}"); + results.send(false).expect("Failed to send test failure due to deserialization"); + break; + } + }, + Err(e) => { + eprintln!("Error deserializing decrypted message: {}", e); + results.send(false).expect("Failed to send test failure due to deserialization after decryption"); + break; + } + } + } + } + } + + // Helper to generate sample messages for testing + fn sample_messages() -> Vec { + let first_message = Payload::Query(Query::Tag(TagQuery::Get)); + let second_message = Payload::Reply(Reply::Ok); + let (test_tag, test_pow) = {( + Tag::new(&[7u8;32],TagPayload{data:vec!()}).expect("Failed to generate test tag"), + Pow::new(&[4u8;32],Action::PublishTag, 213u8) + )}; + let third_message = Payload::Query(Query::Tag(TagQuery::Publish{tag: test_tag, pow: test_pow, nonce: 17u64})); + let fourth_message = Payload::Reply(Reply::Ok); + vec!( + first_message, + second_message, + third_message, + fourth_message + ) + } + + fn gen_shared_sessions(shared: [u8;32], conn_id: u64) -> (ActiveSession, ActiveSession) { + let sender_shared_session: ActiveSession = PendingSession::new(shared.clone(),Some(conn_id)).activate(None).expect("Failed to create shared test session"); + let receiver_shared_session= PendingSession::new(shared,Some(conn_id)).activate(None).expect("Failed to create shared test session"); + (sender_shared_session, receiver_shared_session) + } +} \ No newline at end of file