Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions shared/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
let peer_manager = Arc::new(PeerManagerHandle::new(
MAX_ERRORS_PER_PEER,
param_requests_cancel_token.clone(),
p2p.connection_monitor(),
));

let mut broadcasts = vec![];
Expand Down
4 changes: 4 additions & 0 deletions shared/network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,10 @@ where
}
None
}
pub fn connection_monitor(&self) -> ConnectionMonitor {
self.connection_monitor.clone()
}

pub fn router(&self) -> Arc<Router> {
self.router.clone()
}
Expand Down
50 changes: 39 additions & 11 deletions shared/network/src/p2p_model_sharing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use iroh::protocol::AcceptError;
use iroh::{endpoint::Connection, protocol::ProtocolHandler};
use iroh_blobs::api::Tag;
use iroh_blobs::ticket::BlobTicket;
use std::collections::VecDeque;
use std::collections::{HashMap, HashSet, hash_map::Entry};
use std::collections::{HashMap, HashSet, VecDeque, hash_map::Entry};
use std::io::{Cursor, Write};
use std::time::Duration;
use tch::Tensor;
use thiserror::Error;
use tokenizers::Tokenizer;
Expand All @@ -18,8 +18,8 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn};

use crate::connection_monitor::ConnectionMonitor;
use crate::{NetworkConnection, Networkable, TransmittableDownload};

#[derive(Debug)]
/// Manager for the list of peers to ask for the model parameters and config
pub struct PeerManagerHandle {
Expand All @@ -45,13 +45,18 @@ enum PeerCommand {
}

impl PeerManagerHandle {
pub fn new(max_errors_per_peer: u8, cancellation_token: CancellationToken) -> Self {
pub fn new(
max_errors_per_peer: u8,
cancellation_token: CancellationToken,
connection_monitor: ConnectionMonitor,
) -> Self {
let (peer_tx, peer_rx) = mpsc::unbounded_channel();

// Spawn the peer manager actor
tokio::spawn(peer_manager_actor(
peer_rx,
max_errors_per_peer,
connection_monitor,
cancellation_token,
));

Expand Down Expand Up @@ -110,27 +115,49 @@ struct PeerManagerActor {
errors_per_peers: HashMap<EndpointId, u8>,
/// Max errors we tolerate for a peer to share a parameter blob ticket
max_errors_per_peer: u8,
/// Connection monitor for latency-based peer sorting
connection_monitor: ConnectionMonitor,
}

impl PeerManagerActor {
pub fn new(max_errors_per_peer: u8) -> Self {
pub fn new(max_errors_per_peer: u8, connection_monitor: ConnectionMonitor) -> Self {
Self {
available_peers: VecDeque::new(),
errors_per_peers: HashMap::new(),
max_errors_per_peer,
connection_monitor,
}
}

fn handle_message(&mut self, message: PeerCommand, cancellation_token: CancellationToken) {
match message {
PeerCommand::SetPeers { peers } => {
self.available_peers = VecDeque::from(peers);
let errors_per_peers_vec = self.available_peers.iter().map(|peer| (*peer, 0_u8));
self.errors_per_peers = HashMap::from_iter(errors_per_peers_vec);
let mut peers_with_latency: Vec<_> = peers
.into_iter()
.map(|peer| {
let latency = self.connection_monitor.get_latency(&peer);
(peer, latency)
})
.collect();
peers_with_latency.sort_by_key(|(_, latency)| latency.unwrap_or(Duration::MAX));

self.available_peers = peers_with_latency.iter().map(|(peer, _)| *peer).collect();
self.errors_per_peers = self
.available_peers
.iter()
.map(|peer| (*peer, 0_u8))
.collect();

info!(
"Updated peer list: {} peers available to ask for the model parameters",
self.available_peers.len()
"Updated peer list ({} peers) sorted by latency: {:?}",
self.available_peers.len(),
peers_with_latency
.iter()
.map(|(p, l)| (
p.fmt_short().to_string(),
l.map_or("unknown".into(), |d| format!("{}ms", d.as_millis()))
))
.collect::<Vec<_>>()
);
}
PeerCommand::GetPeer { reply } => {
Expand Down Expand Up @@ -189,9 +216,10 @@ impl PeerManagerActor {
async fn peer_manager_actor(
mut rx: mpsc::UnboundedReceiver<PeerCommand>,
max_errors_per_peer: u8,
connection_monitor: ConnectionMonitor,
cancellation_token: CancellationToken,
) {
let mut actor = PeerManagerActor::new(max_errors_per_peer);
let mut actor = PeerManagerActor::new(max_errors_per_peer, connection_monitor);

while let Some(message) = rx.recv().await {
actor.handle_message(message, cancellation_token.clone());
Expand Down
Loading