diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index aaba6aafe..0fbfb3088 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -127,6 +127,7 @@ impl + 'sta let peer_manager = Arc::new(PeerManagerHandle::new( MAX_ERRORS_PER_PEER, param_requests_cancel_token.clone(), + p2p.connection_monitor(), )); let mut broadcasts = vec![]; diff --git a/shared/network/src/lib.rs b/shared/network/src/lib.rs index c6531a858..41ffa9fe8 100644 --- a/shared/network/src/lib.rs +++ b/shared/network/src/lib.rs @@ -715,6 +715,10 @@ where } None } + pub fn connection_monitor(&self) -> ConnectionMonitor { + self.connection_monitor.clone() + } + pub fn router(&self) -> Arc { self.router.clone() } diff --git a/shared/network/src/p2p_model_sharing.rs b/shared/network/src/p2p_model_sharing.rs index 87d6271e7..205c20fa6 100644 --- a/shared/network/src/p2p_model_sharing.rs +++ b/shared/network/src/p2p_model_sharing.rs @@ -4,9 +4,9 @@ use iroh::protocol::AcceptError; use iroh::{endpoint::Connection, protocol::ProtocolHandler}; use iroh_blobs::api::Tag; use iroh_blobs::ticket::BlobTicket; -use std::collections::VecDeque; -use std::collections::{HashMap, HashSet, hash_map::Entry}; +use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry}; use std::io::{Cursor, Write}; +use std::time::Duration; use tch::Tensor; use thiserror::Error; use tokenizers::Tokenizer; @@ -18,8 +18,8 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn}; +use crate::connection_monitor::ConnectionMonitor; use crate::{NetworkConnection, Networkable, TransmittableDownload}; - #[derive(Debug)] /// Manager for the list of peers to ask for the model parameters and config pub struct PeerManagerHandle { @@ -45,13 +45,18 @@ enum PeerCommand { } impl PeerManagerHandle { - pub fn new(max_errors_per_peer: u8, cancellation_token: CancellationToken) -> Self { + pub fn new( + max_errors_per_peer: u8, + 
cancellation_token: CancellationToken, + connection_monitor: ConnectionMonitor, + ) -> Self { let (peer_tx, peer_rx) = mpsc::unbounded_channel(); // Spawn the peer manager actor tokio::spawn(peer_manager_actor( peer_rx, max_errors_per_peer, + connection_monitor, cancellation_token, )); @@ -110,27 +115,49 @@ struct PeerManagerActor { errors_per_peers: HashMap, /// Max errors we tolerate for a peer to share a parameter blob ticket max_errors_per_peer: u8, + /// Connection monitor for latency-based peer sorting + connection_monitor: ConnectionMonitor, } impl PeerManagerActor { - pub fn new(max_errors_per_peer: u8) -> Self { + pub fn new(max_errors_per_peer: u8, connection_monitor: ConnectionMonitor) -> Self { Self { available_peers: VecDeque::new(), errors_per_peers: HashMap::new(), max_errors_per_peer, + connection_monitor, } } fn handle_message(&mut self, message: PeerCommand, cancellation_token: CancellationToken) { match message { PeerCommand::SetPeers { peers } => { - self.available_peers = VecDeque::from(peers); - let errors_per_peers_vec = self.available_peers.iter().map(|peer| (*peer, 0_u8)); - self.errors_per_peers = HashMap::from_iter(errors_per_peers_vec); + let mut peers_with_latency: Vec<_> = peers + .into_iter() + .map(|peer| { + let latency = self.connection_monitor.get_latency(&peer); + (peer, latency) + }) + .collect(); + peers_with_latency.sort_by_key(|(_, latency)| latency.unwrap_or(Duration::MAX)); + + self.available_peers = peers_with_latency.iter().map(|(peer, _)| *peer).collect(); + self.errors_per_peers = self + .available_peers + .iter() + .map(|peer| (*peer, 0_u8)) + .collect(); info!( - "Updated peer list: {} peers available to ask for the model parameters", - self.available_peers.len() + "Updated peer list ({} peers) sorted by latency: {:?}", + self.available_peers.len(), + peers_with_latency + .iter() + .map(|(p, l)| ( + p.fmt_short().to_string(), + l.map_or("unknown".into(), |d| format!("{}ms", d.as_millis())) + )) + .collect::<Vec<_>>() ); }
PeerCommand::GetPeer { reply } => { @@ -189,9 +216,10 @@ impl PeerManagerActor { async fn peer_manager_actor( mut rx: mpsc::UnboundedReceiver<PeerCommand>, max_errors_per_peer: u8, + connection_monitor: ConnectionMonitor, cancellation_token: CancellationToken, ) { - let mut actor = PeerManagerActor::new(max_errors_per_peer); + let mut actor = PeerManagerActor::new(max_errors_per_peer, connection_monitor); while let Some(message) = rx.recv().await { actor.handle_message(message, cancellation_token.clone());