Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/messages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,16 @@ pub mod action {
Idle,
BatchCompleted {
batch_size: u32,
batches: u32,
},
WaitedForParameterServer,
SentUpdate {
round: u32,
metrics: HashMap<String, f32>,
},
WaitedForUpdate,
AppliedUpdate,
WaitedForNextRound,
PushedToHub,
SentModel,
ReceivedModel,
Expand Down Expand Up @@ -167,14 +171,23 @@ pub mod action {
ExecuteBatch {
batches: u32,
},
WaitForParameterServer {
timeout: SystemTime,
},
SendUpdate {
target: Reference,
weight: f32,
},
WaitForUpdate {
timeout: SystemTime,
},
ApplyUpdate {
source: Reference,
timeout: SystemTime,
},
WaitForNextRound {
timeout: SystemTime,
},
/// DEPRECATED: Temporary path to push final weights to Hugging Face.
/// Prefer dedicated artifact publishing in future revisions.
PushToHub {
Expand Down
4 changes: 2 additions & 2 deletions crates/scheduler/src/bin/hypha-scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use hypha_scheduler::{
config::Config,
metrics_bridge::{AimConnector, CsvConnector, JsonlConnector, MetricsBridge, OtelConnector},
network::Network,
pool::{Pool, PoolConfig, PoolWithStatistics},
pool::{Pool, PoolConfig, PoolWithWorkerProperties},
scheduler_config::{Job as SchedulerJob, MetricsConfig},
scheduling::{batch_scheduler::BatchScheduler, data_scheduler::DataScheduler},
simulation::BasicSimulation,
Expand Down Expand Up @@ -236,7 +236,7 @@ async fn run(config: ConfigWithMetadata<Config>) -> Result<()> {
grace: Duration::from_millis(diloco_config.resources.worker_pool.grace_ms),
},
);
let worker_pool = PoolWithStatistics::<RunningMean>::new(worker_pool);
let worker_pool = PoolWithWorkerProperties::<RunningMean>::new(worker_pool);
let worker_handle = worker_pool.handle();

let parameter_pool = Pool::new(
Expand Down
2 changes: 1 addition & 1 deletion crates/scheduler/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl Network {
exclude_cidrs: Vec<IpNet>,
network_config: &NetworkConfig,
) -> Result<(Self, NetworkDriver), SwarmError> {
let (action_sender, action_receiver) = mpsc::channel(64);
let (action_sender, action_receiver) = mpsc::channel(512);
let meter = metrics::global::meter();
let request_timeout =
(Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10));
Expand Down
133 changes: 84 additions & 49 deletions crates/scheduler/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! and pursuing a target size.

use std::{
collections::{HashMap, HashSet, hash_map::Entry},
collections::{HashMap, HashSet},
fmt::Display,
future::Future,
pin::Pin,
Expand Down Expand Up @@ -395,68 +395,86 @@ where

/// Worker descriptor enriched with an optional last-updated timestamp, an optional statistic value, and the worker's tracked state.
#[derive(Clone, Debug)]
pub struct WorkerDescriptorWithStats {
pub struct WorkerDescriptorWithProperties {
pub peer_id: PeerId,
pub resources: Resources,
pub last_updated: Option<LastUpdated>,
pub statistic: Option<u64>,
pub state: WorkerState,
}

impl WorkerDescriptorWithStats {
impl WorkerDescriptorWithProperties {
pub fn new(
descriptor: &WorkerDescriptor,
last_updated: Option<LastUpdated>,
statistic: Option<u64>,
state: WorkerState,
) -> Self {
Self {
peer_id: descriptor.peer_id,
resources: descriptor.resources,
last_updated,
statistic,
state,
}
}
}

/// Per-worker bookkeeping stored next to the runtime statistic in the pool's
/// property map (the third element of each `WorkerProperties` entry).
///
/// `Default` yields every flag `false`, `samples_processed == 0` and
/// `receiving_from == None`; `properties()` falls back to that value for
/// members that have no recorded entry.
///
/// NOTE(review): the per-field notes below are inferred from the field names
/// only; the visible code never reads or writes them. Confirm against the
/// scheduler logic before relying on them.
#[derive(Clone, Debug, Default)]
pub struct WorkerState {
/// Presumably set once the worker has sent its update for the current round.
pub sent_update: bool,
/// Presumably set once the worker has applied the update it received.
pub applied_update: bool,
/// Presumably a running count of samples the worker has processed.
pub samples_processed: u32,
/// Presumably set while the worker is waiting to receive a model.
pub waiting_for_model: bool,
/// Presumably the peer this worker is currently receiving from, if any.
pub receiving_from: Option<PeerId>,
/// Presumably marks the worker chosen to push the final model.
pub is_pusher: bool,
/// Presumably set once that push has completed.
pub push_done: bool,
}

/// Timestamp of a worker's most recent statistics update.
/// NOTE(review): the unit/epoch is not fixed by the visible code — confirm with callers.
type LastUpdated = u64;
/// Shared, lock-guarded map from peer id to
/// `(last-updated timestamp, runtime statistic, worker state)`.
type WorkerProperties<T> = Arc<RwLock<HashMap<PeerId, (LastUpdated, T, WorkerState)>>>;

/// Snapshot view of current members enriched with statistics.
pub struct PoolWithStatistics<T: RuntimeStatistic> {
pub struct PoolWithWorkerProperties<T: RuntimeStatistic> {
pool: Pool,
handle: PoolStatisticsHandle<T>,
handle: PoolWithWorkerPropertiesHandle<T>,
}

impl<T> PoolWithStatistics<T>
impl<T> PoolWithWorkerProperties<T>
where
T: RuntimeStatistic,
{
pub fn new(pool: Pool) -> Self {
let statistics = Arc::new(RwLock::new(HashMap::default()));
let handle = PoolStatisticsHandle {
let properties = Arc::new(RwLock::new(HashMap::default()));
let handle = PoolWithWorkerPropertiesHandle {
pool: pool.handle(),
statistics,
properties,
_marker: std::marker::PhantomData,
};
Self { pool, handle }
}

/// Returns a cloneable handle that exposes statistics and membership without owning the stream.
pub fn handle(&self) -> PoolStatisticsHandle<T> {
pub fn handle(&self) -> PoolWithWorkerPropertiesHandle<T> {
self.handle.clone()
}

/// Returns current members decorated with their latest statistic (if any),
/// after pruning stale statistics.
pub fn statistics(&self) -> Vec<WorkerDescriptorWithStats> {
self.handle.statistics()
pub fn properties(&self) -> Vec<WorkerDescriptorWithProperties> {
self.handle.properties()
}

/// Update a worker's statistic and last-updated timestamp through the supplied closure.
pub fn update(&self, peer_id: &PeerId, now: u64) {
self.handle.update(peer_id, now)
pub fn update_statistics<F>(&self, peer_id: &PeerId, f: F)
where
F: FnOnce(&mut T, &mut u64),
{
self.handle.update_statistics(peer_id, f)
}
}

impl<T> Stream for PoolWithStatistics<T>
impl<T> Stream for PoolWithWorkerProperties<T>
where
T: RuntimeStatistic + std::marker::Unpin,
{
Expand All @@ -467,67 +485,74 @@ where
}
}

impl<T> PoolWithStatistics<T> where T: RuntimeStatistic {}
impl<T> PoolWithWorkerProperties<T> where T: RuntimeStatistic {}

/// Cloneable handle for statistics + membership/dispatchers without owning the stream.
pub struct PoolStatisticsHandle<T: RuntimeStatistic> {
pub struct PoolWithWorkerPropertiesHandle<T: RuntimeStatistic> {
pool: PoolHandle,
statistics: Arc<RwLock<HashMap<PeerId, (LastUpdated, T)>>>,
properties: WorkerProperties<T>,
_marker: std::marker::PhantomData<T>,
}

impl<T: RuntimeStatistic> Clone for PoolStatisticsHandle<T> {
impl<T: RuntimeStatistic> Clone for PoolWithWorkerPropertiesHandle<T> {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
statistics: Arc::clone(&self.statistics),
properties: Arc::clone(&self.properties),
_marker: std::marker::PhantomData,
}
}
}

impl<T> PoolStatisticsHandle<T>
impl<T> PoolWithWorkerPropertiesHandle<T>
where
T: RuntimeStatistic,
{
pub fn statistics(&self) -> Vec<WorkerDescriptorWithStats> {
pub fn properties(&self) -> Vec<WorkerDescriptorWithProperties> {
let members = &self.pool;

let snapshot = members.inner.load_full();
let active: HashSet<PeerId> = snapshot.iter().map(|worker| worker.peer_id).collect();

let mut statistics = self.statistics.write().expect("statistics lock poisoned");
statistics.retain(|peer_id, _| active.contains(peer_id));
let mut properties = self.properties.write().expect("properties lock poisoned");
properties.retain(|peer_id, _| active.contains(peer_id));

snapshot
.iter()
.map(|descriptor| {
let statistic = statistics
.get(&descriptor.peer_id)
.map(|(last_updated, stats)| (*last_updated, stats.value()));

WorkerDescriptorWithStats::new(
let statistic =
properties
.get(&descriptor.peer_id)
.map(|(last_updated, stats, state)| {
(*last_updated, stats.value(), state.clone())
});

WorkerDescriptorWithProperties::new(
descriptor,
statistic.map(|(last_updated, _)| last_updated),
statistic.map(|(_, stat)| stat),
statistic.as_ref().map(|(last_updated, _, _)| *last_updated),
statistic.as_ref().map(|(_, stat, _)| *stat),
statistic.map(|(_, _, state)| state).unwrap_or_default(),
)
})
.collect()
}

pub fn update(&self, peer_id: &PeerId, now: u64) {
let mut statistics = self.statistics.write().expect("statistics lock poisoned");
/// Runs `f` on the statistic and last-updated timestamp recorded for `peer_id`.
///
/// The closure receives the statistic first and the timestamp second. If the
/// peer has no entry yet, a default one is inserted before `f` is called.
/// The write lock on the property map is held for the duration of the call.
///
/// # Panics
///
/// Panics if the property lock has been poisoned.
pub fn update_statistics<F>(&self, peer_id: &PeerId, f: F)
where
    F: FnOnce(&mut T, &mut u64),
{
    let mut guard = self.properties.write().expect("lock poisoned");
    // Destructure the tuple instead of indexing it; the worker state is untouched here.
    let (last_updated, statistic, _) = guard.entry(*peer_id).or_default();
    f(statistic, last_updated);
}

match statistics.entry(*peer_id) {
Entry::Occupied(mut entry) => {
let (last_updated, stats) = entry.get_mut();
stats.update(now.saturating_sub(*last_updated));
*last_updated = now;
}
Entry::Vacant(entry) => {
entry.insert((now, T::default()));
}
}
/// Runs `f` on the `WorkerState` recorded for `peer_id`.
///
/// If the peer has no entry yet, a default one is inserted before `f` is
/// called. The write lock on the property map is held for the duration of
/// the call.
///
/// # Panics
///
/// Panics if the property lock has been poisoned.
pub fn update_state<F>(&self, peer_id: &PeerId, f: F)
where
    F: FnOnce(&mut WorkerState),
{
    let mut guard = self.properties.write().expect("lock poisoned");
    // Only the state component is exposed; timestamp and statistic stay as they are.
    let (_, _, state) = guard.entry(*peer_id).or_default();
    f(state);
}

pub fn members(&self) -> PoolHandle {
Expand Down Expand Up @@ -764,7 +789,7 @@ mod tests {
let peer_id = worker.peer_id();
let allocator = StubAllocator::new(vec![vec![worker]]);

let mut pool_with_stats = PoolWithStatistics::<RunningMean>::new(Pool::new(
let mut pool_with_stats = PoolWithWorkerProperties::<RunningMean>::new(Pool::new(
allocator,
PoolConfig {
grace: Duration::from_millis(200),
Expand All @@ -783,11 +808,21 @@ mod tests {
.await
.expect("worker should join before timeout");

pool_with_stats.update(&peer_id, 10);
pool_with_stats.update(&peer_id, 25);
pool_with_stats.update_statistics(&peer_id, |stats, last_updated| {
if *last_updated > 0 {
stats.update(10u64.saturating_sub(*last_updated), 1);
}
*last_updated = 10;
});
pool_with_stats.update_statistics(&peer_id, |stats, last_updated| {
if *last_updated > 0 {
stats.update(25u64.saturating_sub(*last_updated), 1);
}
*last_updated = 25;
});

let statistics = pool_with_stats.statistics();
let first = statistics.first().expect("expected member with stats");
let properties = pool_with_stats.properties();
let first = properties.first().expect("expected member with stats");
assert_eq!(first.peer_id, peer_id);
assert_eq!(first.statistic, Some(15));

Expand All @@ -803,6 +838,6 @@ mod tests {
.await
.expect("worker should be removed");

assert!(pool_with_stats.statistics().is_empty());
assert!(pool_with_stats.properties().is_empty());
}
}
4 changes: 2 additions & 2 deletions crates/scheduler/src/scheduler_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub struct DiLoCo {
#[serde(rename = "outer_optimizer")]
pub outer_optimizer: Nesterov,
pub resources: DiLoCoResources,
pub model_destination: Option<ModelDestiantion>,
pub model_destination: Option<ModelDestination>,
}

#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
Expand Down Expand Up @@ -179,7 +179,7 @@ impl From<ModelSource> for Model {
}

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ModelDestiantion {
pub struct ModelDestination {
pub repository: String,
pub token: String,
}
Expand Down
Loading