From 246102d0ff4c0eb8e51ab41b925010bea701d92c Mon Sep 17 00:00:00 2001 From: Orlando Hohmeier Date: Mon, 29 Dec 2025 15:23:26 +0100 Subject: [PATCH 1/5] fix(accelerate): only configure OTEL if endpoint is defined --- .../src/hypha/accelerate_executor/training.py | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py index 733b5d53..9c295db2 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/training.py +++ b/executors/accelerate/src/hypha/accelerate_executor/training.py @@ -38,37 +38,41 @@ CURRENT_MODEL_NAME = "global_weights.pt" MIN_LOOP_TIME_MS = 100 -otel_handler = None -if "OTEL" in os.environ: +# NOTE: Set the root logger level to NOTSET to ensure all messages are captured +# and attach console and OTEL (if configured) handlers to root logger +logging.getLogger().setLevel(logging.NOTSET) + +# NOTE: Set level for httpx and httpcore to WARNING to reduce noise +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("httpcore").setLevel(logging.WARNING) + +console_handler = logging.StreamHandler(sys.stdout) +console_handler.setLevel(logging.INFO) +console_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) +logging.getLogger().addHandler(console_handler) + + +# NOTE: Only configure OTEL exporters if endpoint is defined. +# If no endpoint is configured, skip exporters. 
+otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") +if otel_endpoint: resource = get_aggregated_resources([OTELResourceDetector()]) - # Configure OTEL + exporter = OTLPLogExporter() logger_provider = LoggerProvider(resource=resource) - print(logger_provider) + logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) set_logger_provider(logger_provider) - exporter = OTLPLogExporter() - logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter)) otel_handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider) + logging.getLogger().addHandler(otel_handler) metric_exporter = OTLPMetricExporter() metric_reader = PeriodicExportingMetricReader(metric_exporter) meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) metrics.set_meter_provider(meter_provider) - SystemMetricsInstrumentor().instrument(meter_provider=meter_provider) -# NOTE: Set the root logger level to NOTSET to ensure all messages are captured -# and attach OTLP + console handlers to root logger -console_handler = logging.StreamHandler(sys.stdout) -console_handler.setLevel(logging.INFO) -console_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + SystemMetricsInstrumentor().instrument(meter_provider=meter_provider) -logging.getLogger("httpx").setLevel(logging.WARNING) -logging.getLogger("httpcore").setLevel(logging.WARNING) -logging.getLogger().setLevel(logging.NOTSET) -if otel_handler: - logging.getLogger().addHandler(otel_handler) -logging.getLogger().addHandler(console_handler) logger = logging.getLogger(__name__) From 7ed9953204e56aae54ff34309448c9d267e73e39 Mon Sep 17 00:00:00 2001 From: Orlando Hohmeier Date: Mon, 29 Dec 2025 16:09:20 +0100 Subject: [PATCH 2/5] fix(network): increase action buffer --- crates/scheduler/src/network.rs | 2 +- crates/worker/src/network.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/scheduler/src/network.rs 
b/crates/scheduler/src/network.rs index bb9eff69..f6c1c474 100644 --- a/crates/scheduler/src/network.rs +++ b/crates/scheduler/src/network.rs @@ -110,7 +110,7 @@ impl Network { exclude_cidrs: Vec, network_config: &NetworkConfig, ) -> Result<(Self, NetworkDriver), SwarmError> { - let (action_sender, action_receiver) = mpsc::channel(64); + let (action_sender, action_receiver) = mpsc::channel(512); let meter = metrics::global::meter(); let request_timeout = (Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10)); diff --git a/crates/worker/src/network.rs b/crates/worker/src/network.rs index 49dddb68..c25cb61d 100644 --- a/crates/worker/src/network.rs +++ b/crates/worker/src/network.rs @@ -104,7 +104,7 @@ impl Network { exclude_cidrs: Vec, network_config: &NetworkConfig, ) -> Result<(Self, NetworkDriver), SwarmError> { - let (action_sender, action_receiver) = mpsc::channel(64); + let (action_sender, action_receiver) = mpsc::channel(512); let meter = metrics::global::meter(); let request_timeout = (Duration::from_millis(network_config.rtt_ms()) * 10).max(Duration::from_secs(10)); From 8af0de67c8c317b5a62ab539353964720599455f Mon Sep 17 00:00:00 2001 From: Orlando Hohmeier Date: Mon, 29 Dec 2025 18:29:50 +0100 Subject: [PATCH 3/5] fix: centralize worker state updates to avoid sync race Refactor the batch scheduler and worker pool so worker state is updated in a single place. This prevents duplicate updates and resolves worker sync startup issues resulting in training failures. 
--- crates/scheduler/src/bin/hypha-scheduler.rs | 4 +- crates/scheduler/src/pool.rs | 134 +- .../src/scheduling/batch_scheduler.rs | 1228 ++++++++++++++--- 3 files changed, 1115 insertions(+), 251 deletions(-) diff --git a/crates/scheduler/src/bin/hypha-scheduler.rs b/crates/scheduler/src/bin/hypha-scheduler.rs index 6fc8ec4c..3f03e0f6 100644 --- a/crates/scheduler/src/bin/hypha-scheduler.rs +++ b/crates/scheduler/src/bin/hypha-scheduler.rs @@ -24,7 +24,7 @@ use hypha_scheduler::{ config::Config, metrics_bridge::{AimConnector, CsvConnector, JsonlConnector, MetricsBridge, OtelConnector}, network::Network, - pool::{Pool, PoolConfig, PoolWithStatistics}, + pool::{Pool, PoolConfig, PoolWithWorkerProperties}, scheduler_config::{Job as SchedulerJob, MetricsConfig}, scheduling::{batch_scheduler::BatchScheduler, data_scheduler::DataScheduler}, simulation::BasicSimulation, @@ -236,7 +236,7 @@ async fn run(config: ConfigWithMetadata) -> Result<()> { grace: Duration::from_millis(diloco_config.resources.worker_pool.grace_ms), }, ); - let worker_pool = PoolWithStatistics::::new(worker_pool); + let worker_pool = PoolWithWorkerProperties::::new(worker_pool); let worker_handle = worker_pool.handle(); let parameter_pool = Pool::new( diff --git a/crates/scheduler/src/pool.rs b/crates/scheduler/src/pool.rs index 3b007efe..e7c7e1d3 100644 --- a/crates/scheduler/src/pool.rs +++ b/crates/scheduler/src/pool.rs @@ -4,7 +4,7 @@ //! and pursuing a target size. use std::{ - collections::{HashMap, HashSet, hash_map::Entry}, + collections::{HashMap, HashSet}, fmt::Display, future::Future, pin::Pin, @@ -395,68 +395,87 @@ where /// Descriptor enriched with an optional statistic value. 
#[derive(Clone, Debug)] -pub struct WorkerDescriptorWithStats { +pub struct WorkerDescriptorWithProperties { pub peer_id: PeerId, pub resources: Resources, pub last_updated: Option, pub statistic: Option, + pub state: WorkerState, } -impl WorkerDescriptorWithStats { +impl WorkerDescriptorWithProperties { pub fn new( descriptor: &WorkerDescriptor, last_updated: Option, statistic: Option, + state: WorkerState, ) -> Self { Self { peer_id: descriptor.peer_id, resources: descriptor.resources, last_updated, statistic, + state, } } } +#[derive(Clone, Debug, Default)] +pub struct WorkerState { + pub sent_update: bool, + pub applied_update: bool, + pub applied_final_update: bool, + pub samples_processed: u32, + pub waiting_for_model: bool, + pub receiving_from: Option, + pub is_pusher: bool, + pub push_done: bool, +} + type LastUpdated = u64; +type WorkerProperties = Arc>>; /// Snapshot view of current members enriched with statistics. -pub struct PoolWithStatistics { +pub struct PoolWithWorkerProperties { pool: Pool, - handle: PoolStatisticsHandle, + handle: PoolWithWorkerPropertiesHandle, } -impl PoolWithStatistics +impl PoolWithWorkerProperties where T: RuntimeStatistic, { pub fn new(pool: Pool) -> Self { - let statistics = Arc::new(RwLock::new(HashMap::default())); - let handle = PoolStatisticsHandle { + let properties = Arc::new(RwLock::new(HashMap::default())); + let handle = PoolWithWorkerPropertiesHandle { pool: pool.handle(), - statistics, + properties, _marker: std::marker::PhantomData, }; Self { pool, handle } } /// Returns a cloneable handle that exposes statistics and membership without owning the stream. - pub fn handle(&self) -> PoolStatisticsHandle { + pub fn handle(&self) -> PoolWithWorkerPropertiesHandle { self.handle.clone() } /// Returns current members decorated with their latest statistic (if any), /// after pruning stale statistics. 
- pub fn statistics(&self) -> Vec { - self.handle.statistics() + pub fn properties(&self) -> Vec { + self.handle.properties() } /// Update statistics for a worker with the given timestamp. - pub fn update(&self, peer_id: &PeerId, now: u64) { - self.handle.update(peer_id, now) + pub fn update_statistics(&self, peer_id: &PeerId, f: F) + where + F: FnOnce(&mut T, &mut u64), + { + self.handle.update_statistics(peer_id, f) } } -impl Stream for PoolWithStatistics +impl Stream for PoolWithWorkerProperties where T: RuntimeStatistic + std::marker::Unpin, { @@ -467,67 +486,74 @@ where } } -impl PoolWithStatistics where T: RuntimeStatistic {} +impl PoolWithWorkerProperties where T: RuntimeStatistic {} /// Cloneable handle for statistics + membership/dispatchers without owning the stream. -pub struct PoolStatisticsHandle { +pub struct PoolWithWorkerPropertiesHandle { pool: PoolHandle, - statistics: Arc>>, + properties: WorkerProperties, _marker: std::marker::PhantomData, } -impl Clone for PoolStatisticsHandle { +impl Clone for PoolWithWorkerPropertiesHandle { fn clone(&self) -> Self { Self { pool: self.pool.clone(), - statistics: Arc::clone(&self.statistics), + properties: Arc::clone(&self.properties), _marker: std::marker::PhantomData, } } } -impl PoolStatisticsHandle +impl PoolWithWorkerPropertiesHandle where T: RuntimeStatistic, { - pub fn statistics(&self) -> Vec { + pub fn properties(&self) -> Vec { let members = &self.pool; let snapshot = members.inner.load_full(); let active: HashSet = snapshot.iter().map(|worker| worker.peer_id).collect(); - let mut statistics = self.statistics.write().expect("statistics lock poisoned"); - statistics.retain(|peer_id, _| active.contains(peer_id)); + let mut properties = self.properties.write().expect("properties lock poisoned"); + properties.retain(|peer_id, _| active.contains(peer_id)); snapshot .iter() .map(|descriptor| { - let statistic = statistics - .get(&descriptor.peer_id) - .map(|(last_updated, stats)| (*last_updated, 
stats.value())); - - WorkerDescriptorWithStats::new( + let statistic = + properties + .get(&descriptor.peer_id) + .map(|(last_updated, stats, state)| { + (*last_updated, stats.value(), state.clone()) + }); + + WorkerDescriptorWithProperties::new( descriptor, - statistic.map(|(last_updated, _)| last_updated), - statistic.map(|(_, stat)| stat), + statistic.as_ref().map(|(last_updated, _, _)| *last_updated), + statistic.as_ref().map(|(_, stat, _)| *stat), + statistic.map(|(_, _, state)| state).unwrap_or_default(), ) }) .collect() } - pub fn update(&self, peer_id: &PeerId, now: u64) { - let mut statistics = self.statistics.write().expect("statistics lock poisoned"); + pub fn update_statistics(&self, peer_id: &PeerId, f: F) + where + F: FnOnce(&mut T, &mut u64), + { + let mut data = self.properties.write().expect("lock poisoned"); + let entry = data.entry(*peer_id).or_default(); + f(&mut entry.1, &mut entry.0); + } - match statistics.entry(*peer_id) { - Entry::Occupied(mut entry) => { - let (last_updated, stats) = entry.get_mut(); - stats.update(now.saturating_sub(*last_updated)); - *last_updated = now; - } - Entry::Vacant(entry) => { - entry.insert((now, T::default())); - } - } + pub fn update_state(&self, peer_id: &PeerId, f: F) + where + F: FnOnce(&mut WorkerState), + { + let mut data = self.properties.write().expect("lock poisoned"); + let entry = data.entry(*peer_id).or_default(); + f(&mut entry.2); } pub fn members(&self) -> PoolHandle { @@ -764,7 +790,7 @@ mod tests { let peer_id = worker.peer_id(); let allocator = StubAllocator::new(vec![vec![worker]]); - let mut pool_with_stats = PoolWithStatistics::::new(Pool::new( + let mut pool_with_stats = PoolWithWorkerProperties::::new(Pool::new( allocator, PoolConfig { grace: Duration::from_millis(200), @@ -783,11 +809,21 @@ mod tests { .await .expect("worker should join before timeout"); - pool_with_stats.update(&peer_id, 10); - pool_with_stats.update(&peer_id, 25); + pool_with_stats.update_statistics(&peer_id, |stats, 
last_updated| { + if *last_updated > 0 { + stats.update(10u64.saturating_sub(*last_updated)); + } + *last_updated = 10; + }); + pool_with_stats.update_statistics(&peer_id, |stats, last_updated| { + if *last_updated > 0 { + stats.update(25u64.saturating_sub(*last_updated)); + } + *last_updated = 25; + }); - let statistics = pool_with_stats.statistics(); - let first = statistics.first().expect("expected member with stats"); + let properties = pool_with_stats.properties(); + let first = properties.first().expect("expected member with stats"); assert_eq!(first.peer_id, peer_id); assert_eq!(first.statistic, Some(15)); @@ -803,6 +839,6 @@ mod tests { .await .expect("worker should be removed"); - assert!(pool_with_stats.statistics().is_empty()); + assert!(pool_with_stats.properties().is_empty()); } } diff --git a/crates/scheduler/src/scheduling/batch_scheduler.rs b/crates/scheduler/src/scheduling/batch_scheduler.rs index 0056fb82..8b88afe5 100644 --- a/crates/scheduler/src/scheduling/batch_scheduler.rs +++ b/crates/scheduler/src/scheduling/batch_scheduler.rs @@ -1,5 +1,4 @@ use std::{ - collections::{HashMap, HashSet, hash_map::Entry}, sync::Arc, time::{Duration, Instant, SystemTime}, }; @@ -28,7 +27,7 @@ use uuid::Uuid; use crate::{ metrics_bridge::Metrics, network::Network, - pool::{PoolHandle, PoolStatisticsHandle}, + pool::{PoolHandle, PoolWithWorkerPropertiesHandle}, scheduler_config::ModelDestiantion, simulation::Simulation, statistics::RuntimeStatistic, @@ -38,7 +37,6 @@ use crate::{ // decide when to instruct the parameter server to aggregate. struct RoundState { aggregated_updates: bool, - sent_updates: HashSet, first_update_at: Option, round_started_at: Instant, aggregate_started_at: Option, @@ -46,19 +44,14 @@ struct RoundState { grace: Duration, round: u32, update_rounds: u32, - push_assigned: Option, training_complete: bool, - applied_final_update: HashSet, push_done: bool, - // NOTE: Tracks workers that have applied the update for the current round. 
- applied_updates: HashSet, } impl Default for RoundState { fn default() -> Self { Self { aggregated_updates: false, - sent_updates: HashSet::new(), first_update_at: None, round_started_at: Instant::now(), aggregate_started_at: None, @@ -66,11 +59,8 @@ impl Default for RoundState { grace: Duration::from_millis(0), round: 0, update_rounds: 0, - push_assigned: None, training_complete: false, - applied_final_update: HashSet::new(), push_done: false, - applied_updates: HashSet::new(), } } } @@ -79,9 +69,6 @@ impl Default for RoundState { struct TrainingState { update_target: u32, counter: u32, - peer_updates: HashMap, - worker_without_model: Vec, - receive_from: HashMap, } impl TrainingState { @@ -89,23 +76,11 @@ impl TrainingState { Self { update_target, counter: 0, - peer_updates: HashMap::new(), - worker_without_model: vec![], - receive_from: HashMap::new(), } } - fn record_batch(&mut self, batch_size: u32, peer_id: PeerId) { + fn record_batch(&mut self, batch_size: u32) { self.counter = self.counter.saturating_add(batch_size); - match self.peer_updates.entry(peer_id) { - Entry::Occupied(mut entry) => { - let processed = entry.get(); - entry.insert(processed.saturating_add(batch_size)); - } - Entry::Vacant(entry) => { - entry.insert(batch_size); - } - } } fn get_count(&self) -> u32 { @@ -118,31 +93,6 @@ impl TrainingState { fn reset_round(&mut self) { self.counter = 0; - self.peer_updates = HashMap::new(); - } - - fn get_peer_updates(&self, peer_id: &PeerId) -> u32 { - *self.peer_updates.get(peer_id).unwrap_or(&0u32) - } - - fn pop_worker_without_model(&mut self) -> Option { - self.worker_without_model.pop() - } - - fn push_worker_without_model(&mut self, peer_id: PeerId) { - self.worker_without_model.push(peer_id); - } - - fn remove_receive_from(&mut self, peer_id: &PeerId) -> Option { - self.receive_from.remove(peer_id) - } - - fn insert_receive_from(&mut self, source: PeerId, destination: PeerId) { - self.receive_from.insert(destination, source); - } - - fn 
get_waiting_workers(&self) -> &[PeerId] { - &self.worker_without_model[..] } } @@ -166,7 +116,7 @@ pub enum BatchSchedulerError { #[allow(clippy::too_many_arguments)] async fn schedule( tx: Sender<(PeerId, Metrics)>, - worker_pool: PoolStatisticsHandle, + worker_pool: PoolWithWorkerPropertiesHandle, parameter_pool: PoolHandle, round_state: Arc>, training_state: Arc>, @@ -206,19 +156,21 @@ where timeout: short_idle, }) } else { - training_state - .lock() - .await - .push_worker_without_model(peer_id); + worker_pool.update_state(&peer_id, |s| s.waiting_for_model = true); ExecutorAction::Train(TrainAction::WaitForModel { timeout: wait_model, }) } } TrainStatus::WaitedForModel => { - if let Some(sending_peer) = - training_state.lock().await.remove_receive_from(&peer_id) - { + let sending_peer = { + let snapshot = worker_pool.properties(); + let worker = snapshot.iter().find(|w| w.peer_id == peer_id); + worker.and_then(|w| w.state.receiving_from) + }; + + if let Some(sending_peer) = sending_peer { + worker_pool.update_state(&peer_id, |s| s.receiving_from = None); ExecutorAction::Train(TrainAction::ReceiveModel { source: Reference::Peers { peers: vec![sending_peer], @@ -244,19 +196,22 @@ where TrainStatus::Idle => { let mut state = round_state.lock().await; if !state.training_complete { - let snapshot = worker_pool.statistics(); + let snapshot = worker_pool.properties(); let peer_position = snapshot .iter() .position(|w| w.peer_id == peer_id) .unwrap_or(0); - let (count, update_target, peer_contribution) = { + // Get peer contribution from pool + let peer_contribution = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.samples_processed) + .unwrap_or(0); + + let (count, update_target) = { let training = training_state.lock().await; - ( - training.get_count(), - training.get_update_target(), - training.get_peer_updates(&peer_id), - ) + (training.get_count(), training.get_update_target()) }; let stats: Vec = @@ -270,8 +225,8 @@ where .map(|w| 
(batch_sizer)(&w.resources)) .collect(); - let (should_update, projected_target, batches) = if update_target <= count { - (true, count, 0) + let (should_update, projected_target) = if update_target <= count { + (true, count) } else if !snapshot.is_empty() && batch_sizes.iter().all(|&b| b > 0) && stats.iter().all(|&s| s > 0 && s < u64::MAX) @@ -288,7 +243,6 @@ where time = %time, count = %cnt, peer = %peer_id, - capped, "Simulation with projection {:?} and {:?}", projection, update_target.saturating_sub(count) @@ -299,13 +253,25 @@ where && peer_position < projection.len() && projection[peer_position] == 0, count.saturating_add(cnt.unsigned_abs()), - projection[peer_position], ) } else { - (false, count, multi_batch_size) + (false, count) }; - if state.aggregated_updates && !state.applied_updates.contains(&peer_id) { + // Check if peer has applied update or sent update + let (has_applied_update, has_sent_update, _applied_final_update) = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| { + ( + w.state.applied_update, + w.state.sent_update, + w.state.applied_final_update, + ) + }) + .unwrap_or((false, false, false)); + + if state.aggregated_updates && !has_applied_update { ExecutorAction::Train(TrainAction::ApplyUpdate { source: Reference::Peers { peers: parameter_servers, @@ -314,12 +280,14 @@ where }, timeout: now + Duration::from_secs(10), }) - } else if state.sent_updates.contains(&peer_id) { + } else if has_sent_update { ExecutorAction::Train(TrainAction::Idle { timeout: short_idle, }) } else if !should_update { - ExecutorAction::Train(TrainAction::ExecuteBatch { batches }) + ExecutorAction::Train(TrainAction::ExecuteBatch { + batches: multi_batch_size, + }) } else if parameter_servers.is_empty() { // NOTE: If we need to send an update but there are no parameter servers, // we must wait (idle) until one becomes available. 
@@ -340,51 +308,91 @@ where } else if state.push_done { cancel.cancel(); ExecutorAction::Train(TrainAction::Terminate) - } else if state.push_assigned.is_none() - && state.applied_final_update.contains(&peer_id) - { - state.push_assigned = Some(peer_id); - if let Some(destination) = push_destination.as_ref().as_ref() { - ExecutorAction::Train(TrainAction::PushToHub { - repository: destination.repository.clone(), - token: destination.token.clone(), + } else { + let snapshot = worker_pool.properties(); + let has_push_assignment = snapshot.iter().any(|w| w.state.is_pusher); + let (applied_final_update, _am_pusher, has_applied_update) = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| { + ( + w.state.applied_final_update, + w.state.is_pusher, + w.state.applied_update, + ) + }) + .unwrap_or((false, false, false)); + + if state.aggregated_updates && !has_applied_update { + ExecutorAction::Train(TrainAction::ApplyUpdate { + source: Reference::Peers { + peers: parameter_servers, + strategy: SelectionStrategy::All, + resource: None, + }, + timeout: now + Duration::from_secs(10), }) + } else if !has_push_assignment && applied_final_update { + // Assign self as pusher + worker_pool.update_state(&peer_id, |s| s.is_pusher = true); + + if let Some(destination) = push_destination.as_ref().as_ref() { + ExecutorAction::Train(TrainAction::PushToHub { + repository: destination.repository.clone(), + token: destination.token.clone(), + }) + } else { + // Should not occur due to guard above. + state.push_done = true; + ExecutorAction::Train(TrainAction::Terminate) + } } else { - // Should not occur due to guard above. 
- state.push_done = true; - ExecutorAction::Train(TrainAction::Terminate) + ExecutorAction::Train(TrainAction::Idle { + timeout: short_idle, + }) } - } else { - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) } } TrainStatus::BatchCompleted { batch_size } => { - worker_pool.update(&peer_id, since_start); + worker_pool.update_statistics(&peer_id, |stats, last_updated| { + if *last_updated > 0 { + stats.update(since_start.saturating_sub(*last_updated)); + } + *last_updated = since_start; + }); - let snapshot = worker_pool.statistics(); + let snapshot = worker_pool.properties(); let peer_position = snapshot .iter() .position(|w| w.peer_id == peer_id) .unwrap_or(0); - let (count, update_target, peer_contribution) = { + let (count, update_target) = { let mut training = training_state.lock().await; - training.record_batch(batch_size, peer_id); - ( - training.get_count(), - training.get_update_target(), - training.get_peer_updates(&peer_id), - ) + training.record_batch(batch_size); + (training.get_count(), training.get_update_target()) }; + // Update per-worker samples + worker_pool.update_state(&peer_id, |s| { + s.samples_processed = s.samples_processed.saturating_add(batch_size); + }); + + // Get fresh snapshot for peer contribution after update + let peer_contribution = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.samples_processed.saturating_add(batch_size)) + .unwrap_or(batch_size); + let (training_complete, sent_update) = { let state = round_state.lock().await; - ( - state.training_complete, - state.sent_updates.contains(&peer_id), - ) + let sent = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.sent_update) + .unwrap_or(false); + (state.training_complete, sent) }; if training_complete || sent_update { @@ -460,10 +468,12 @@ where } } TrainStatus::SentUpdate { mut metrics, .. 
} => { - let worker_samples = { - let training = training_state.lock().await; - training.get_peer_updates(&peer_id) as f32 - }; + let snapshot = worker_pool.properties(); + let worker_samples = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.samples_processed) + .unwrap_or(0) as f32; let (round, round_started_at) = { let state = round_state.lock().await; @@ -485,14 +495,17 @@ where tx.send((peer_id, Metrics { round, metrics })) .await .map_err(BatchSchedulerError::from)?; + // NOTE: Track workers that have sent their update for the current round. + worker_pool.update_state(&peer_id, |s| s.sent_update = true); + let snapshot = worker_pool.properties(); // Refresh after update + let mut state = round_state.lock().await; - state.sent_updates.insert(peer_id); if state.first_update_at.is_none() { state.first_update_at = Some(Instant::now()); } - let total_workers = worker_pool.statistics().len(); - let sent = state.sent_updates.len(); + let total_workers = snapshot.len(); + let sent = snapshot.iter().filter(|w| w.state.sent_update).count(); let elapsed_ms = state .first_update_at .map(|t| t.elapsed().as_millis() as u64) @@ -514,11 +527,11 @@ where } TrainStatus::AppliedUpdate => { let training_complete = { - let mut state = round_state.lock().await; - state.applied_updates.insert(peer_id); + worker_pool.update_state(&peer_id, |s| s.applied_update = true); + let state = round_state.lock().await; if state.training_complete { - state.applied_final_update.insert(peer_id); + worker_pool.update_state(&peer_id, |s| s.applied_final_update = true); true } else { false @@ -530,9 +543,20 @@ where timeout: now + Duration::from_millis(500), }) } else { - let mut training = training_state.lock().await; - if let Some(update_worker) = training.pop_worker_without_model() { - training.insert_receive_from(peer_id, update_worker); + // Find a worker waiting for model + let snapshot = worker_pool.properties(); + let waiting_worker = snapshot + .iter() + .find(|w| 
w.state.waiting_for_model) + .map(|w| w.peer_id); + + if let Some(update_worker) = waiting_worker { + // Mark them as not waiting and set receiving from + worker_pool.update_state(&update_worker, |s| { + s.waiting_for_model = false; + s.receiving_from = Some(peer_id); + }); + ExecutorAction::Train(TrainAction::SendModel { target: Reference::Peers { peers: vec![update_worker], @@ -550,8 +574,15 @@ where } } TrainStatus::PushedToHub => { - let mut state = round_state.lock().await; - if state.push_assigned == Some(peer_id) { + let is_pusher = worker_pool + .properties() + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.is_pusher) + .unwrap_or(false); + + if is_pusher { + let mut state = round_state.lock().await; state.push_done = true; } @@ -559,24 +590,24 @@ where } TrainStatus::Error(TrainError::Connection { message }) => { tracing::warn!(%peer_id, message = %message, "Worker reported connection error"); - { - let mut state = round_state.lock().await; - if state.push_assigned == Some(peer_id) { - state.push_assigned = None; + + worker_pool.update_state(&peer_id, |s| { + if s.is_pusher { + s.is_pusher = false; } - } + }); + ExecutorAction::Train(TrainAction::Idle { timeout: short_idle, }) } TrainStatus::Error(TrainError::Other { message }) => { tracing::warn!(%peer_id, message = %message, "Worker reported error"); - { - let mut state = round_state.lock().await; - if state.push_assigned == Some(peer_id) { - state.push_assigned = None; + worker_pool.update_state(&peer_id, |s| { + if s.is_pusher { + s.is_pusher = false; } - } + }); ExecutorAction::Train(TrainAction::Terminate) } TrainStatus::Terminated => ExecutorAction::Train(TrainAction::Terminate), @@ -585,7 +616,18 @@ where AggregateStatus::Idle => { let training_complete = { round_state.lock().await.training_complete }; if training_complete { - ExecutorAction::Aggregate(AggregateAction::Terminate) + let all_applied = worker_pool + .properties() + .iter() + .all(|w| w.state.applied_final_update); + + if 
all_applied { + ExecutorAction::Aggregate(AggregateAction::Terminate) + } else { + ExecutorAction::Aggregate(AggregateAction::Idle { + timeout: short_idle, + }) + } } else if Some(peer_id) != primary_ps { if let Some(primary) = primary_ps { tracing::debug!( @@ -598,16 +640,12 @@ where timeout: short_idle, }) } else { - let workers: Vec<_> = { - let training = training_state.lock().await; - let non_participating_worker = training.get_waiting_workers(); - worker_pool - .statistics() - .into_iter() - .map(|w| w.peer_id) - .filter(|w| !non_participating_worker.contains(w)) - .collect() - }; + let snapshot = worker_pool.properties(); + let workers: Vec<_> = snapshot + .iter() + .filter(|w| !w.state.waiting_for_model) + .map(|w| w.peer_id) + .collect(); if workers.is_empty() { ExecutorAction::Aggregate(AggregateAction::Idle { @@ -618,9 +656,15 @@ where // or when a quorum (min workers) have sent updates and the // grace period has elapsed since the first update in this round. let mut state = round_state.lock().await; - let all_sent = workers.iter().all(|w| state.sent_updates.contains(w)); + let all_sent = snapshot + .iter() + .filter(|w| !w.state.waiting_for_model) + .all(|w| w.state.sent_update); + + let sent_count = snapshot.iter().filter(|w| w.state.sent_update).count(); + let effective_quorum = state.min_quorum.min(workers.len()); - let quorum_met = state.sent_updates.len() >= effective_quorum; + let quorum_met = sent_count >= effective_quorum; let timebox_elapsed = state .first_update_at .map(|t| t.elapsed() >= state.grace) @@ -655,7 +699,7 @@ where }) } else { let workers: Vec<_> = worker_pool - .statistics() + .properties() .into_iter() .map(|w| w.peer_id) .collect(); @@ -669,7 +713,11 @@ where let round = { let mut state = round_state.lock().await; state.aggregated_updates = true; - state.applied_updates.clear(); + // reset applied updates in pool + let snapshot = worker_pool.properties(); + for w in snapshot { + worker_pool.update_state(&w.peer_id, |s| 
s.applied_update = false); + } state.round }; tracing::info!(round = %round, "Trigger BroadcastUpdate"); @@ -723,7 +771,15 @@ where tracing::info!(round = state.round, "Broadcast completed; advancing round"); - state.sent_updates.clear(); + // Reset per-round state in pool + let snapshot = worker_pool.properties(); + for w in snapshot { + worker_pool.update_state(&w.peer_id, |s| { + s.sent_update = false; + s.samples_processed = 0; + }); + } + state.first_update_at = None; state.aggregate_started_at = None; state.round_started_at = Instant::now(); @@ -748,18 +804,15 @@ where "Next round started; training state reset" ); - let training_complete = { + { let mut state = round_state.lock().await; - state.aggregated_updates = false; - state.training_complete - }; - if training_complete { - ExecutorAction::Aggregate(AggregateAction::Terminate) - } else { - ExecutorAction::Aggregate(AggregateAction::Idle { - timeout: short_idle, - }) + if !state.training_complete { + state.aggregated_updates = false; + } } + ExecutorAction::Aggregate(AggregateAction::Idle { + timeout: short_idle, + }) } AggregateStatus::Error(AggregateError::Connection { message }) => { tracing::warn!(%peer_id, message = %message, "Aggregator reported connection error"); @@ -799,7 +852,7 @@ impl BatchScheduler { #[allow(clippy::too_many_arguments)] pub async fn run( network: Network, - worker_pool: PoolStatisticsHandle, + worker_pool: PoolWithWorkerPropertiesHandle, parameter_pool: PoolHandle, id: Uuid, min_quorum: usize, @@ -826,7 +879,6 @@ impl BatchScheduler { let push_destination = push_destination.clone(); // NOTE: Track per-round SentUpdate signals to decide when to trigger aggregation. 
let round_state = Arc::new(Mutex::new(RoundState { - sent_updates: HashSet::new(), first_update_at: None, round_started_at: start, aggregate_started_at: None, @@ -834,12 +886,9 @@ impl BatchScheduler { grace, round: 0, update_rounds, - push_assigned: None, training_complete: false, - applied_final_update: HashSet::new(), push_done: false, aggregated_updates: false, - applied_updates: HashSet::new(), })); let training_state = Arc::new(Mutex::new(TrainingState::new(samples_between_updates))); network @@ -903,7 +952,8 @@ impl BatchScheduler { #[cfg(test)] mod batch_scheduler_tests { use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, + sync::Arc, time::{Instant, SystemTime}, }; @@ -911,13 +961,13 @@ mod batch_scheduler_tests { use hypha_messages::{ Reference, SelectionStrategy, action::{ - ActionRequest, AggregateStatus, ExecutorAction, ExecutorStatus, TrainAction, - TrainStatus, + ActionRequest, AggregateAction, AggregateStatus, ExecutorAction, ExecutorStatus, + TrainAction, TrainStatus, }, }; use hypha_resources::Resources; use libp2p::PeerId; - use tokio::time::Duration; + use tokio::{sync::mpsc::Sender, time::Duration}; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -925,8 +975,10 @@ mod batch_scheduler_tests { use crate::{ allocator::{Allocator, AllocatorError}, metrics_bridge::Metrics, - pool::{Pool, PoolConfig, PoolWithStatistics}, - scheduler_config::PriceRange, + pool::{ + Pool, PoolConfig, PoolHandle, PoolWithWorkerProperties, PoolWithWorkerPropertiesHandle, + }, + scheduler_config::{ModelDestiantion, PriceRange}, simulation::BasicSimulation, statistics::{RunningMean, RuntimeStatistic}, worker::{TestWorkerBuilder, Worker}, @@ -1016,7 +1068,7 @@ mod batch_scheduler_tests { grace: Duration::from_secs(1), }, ); - let worker_pool = PoolWithStatistics::::new(pool); + let worker_pool = PoolWithWorkerProperties::::new(pool); let worker_handle = worker_pool.handle(); let ps_pool = Pool::new( NoopAllocator, @@ -1035,10 +1087,10 @@ mod 
batch_scheduler_tests { let parameter_pool = ps_pool.handle(); let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(1); - let round = std::sync::Arc::new(tokio::sync::Mutex::new(RoundState::default())); - let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(10))); - let batch_sizer = std::sync::Arc::new(|_: &Resources| 1u32); - let push_destination = std::sync::Arc::new(None); + let round = Arc::new(tokio::sync::Mutex::new(RoundState::default())); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(10))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let push_destination = Arc::new(None); let token = CancellationToken::new(); let resp = schedule::( tx, @@ -1063,9 +1115,7 @@ mod batch_scheduler_tests { .unwrap(); match resp.next { - hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { - batches: 3, - }) => {} + ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 3 }) => {} other => panic!("Unexpected response: {:?}", other), } } @@ -1086,7 +1136,7 @@ mod batch_scheduler_tests { grace: Duration::from_secs(1), }, ); - let worker_pool = PoolWithStatistics::::new(pool); + let worker_pool = PoolWithWorkerProperties::::new(pool); let worker_handle = worker_pool.handle(); let ps_pool = Pool::new( NoopAllocator, @@ -1105,10 +1155,10 @@ mod batch_scheduler_tests { let parameter_pool = ps_pool.handle(); let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(1); - let round = std::sync::Arc::new(tokio::sync::Mutex::new(RoundState::default())); - let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(100))); - let batch_sizer = std::sync::Arc::new(|_: &Resources| 1u32); - let push_destination = std::sync::Arc::new(None); + let round = Arc::new(tokio::sync::Mutex::new(RoundState::default())); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(100))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let 
push_destination = Arc::new(None); let token = CancellationToken::new(); let resp = schedule::( tx, @@ -1133,9 +1183,7 @@ mod batch_scheduler_tests { .unwrap(); match resp.next { - hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { - batches: 3, - }) => {} + ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 3 }) => {} other => panic!("Unexpected response: {:?}", other), } } @@ -1156,7 +1204,7 @@ mod batch_scheduler_tests { grace: Duration::from_secs(1), }, ); - let worker_pool = PoolWithStatistics::::new(pool); + let worker_pool = PoolWithWorkerProperties::::new(pool); let worker_handle = worker_pool.handle(); let ps_pool = Pool::new( NoopAllocator, @@ -1175,10 +1223,10 @@ mod batch_scheduler_tests { let parameter_pool = ps_pool.handle(); let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(1); - let round = std::sync::Arc::new(tokio::sync::Mutex::new(RoundState::default())); - let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(10))); - let batch_sizer = std::sync::Arc::new(|_: &Resources| 1u32); - let push_destination = std::sync::Arc::new(None); + let round = Arc::new(tokio::sync::Mutex::new(RoundState::default())); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(10))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let push_destination = Arc::new(None); let token = CancellationToken::new(); let resp = schedule::( tx, @@ -1203,7 +1251,7 @@ mod batch_scheduler_tests { .unwrap(); match resp.next { - hypha_messages::action::ExecutorAction::Train(TrainAction::Idle { .. }) => {} + ExecutorAction::Train(TrainAction::Idle { .. 
}) => {} other => panic!("Unexpected response: {:?}", other), } } @@ -1224,7 +1272,7 @@ mod batch_scheduler_tests { grace: Duration::from_secs(1), }, ); - let worker_pool = PoolWithStatistics::::new(pool); + let worker_pool = PoolWithWorkerProperties::::new(pool); let worker_handle = worker_pool.handle(); let ps_pool = Pool::new( NoopAllocator, @@ -1243,11 +1291,11 @@ mod batch_scheduler_tests { let parameter_pool = ps_pool.handle(); let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(1); - let round = std::sync::Arc::new(tokio::sync::Mutex::new(RoundState::default())); + let round = Arc::new(tokio::sync::Mutex::new(RoundState::default())); // Initialize with 0 samples remaining to simulate end of round - let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(0))); - let batch_sizer = std::sync::Arc::new(|_: &Resources| 1u32); - let push_destination = std::sync::Arc::new(None); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(0))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let push_destination = Arc::new(None); let token = CancellationToken::new(); let resp = schedule::( tx, @@ -1273,11 +1321,793 @@ mod batch_scheduler_tests { match resp.next { // Should be Idle because samples == 0 and no PS, not ExecuteBatch - hypha_messages::action::ExecutorAction::Train(TrainAction::Idle { .. }) => {} + ExecutorAction::Train(TrainAction::Idle { .. 
}) => {} other => panic!("Expected Idle when samples=0 and no PS, got: {:?}", other), } } + #[tokio::test] + async fn final_round_applies_once_and_pushes_model() { + let w1_id = PeerId::random(); + let w2_id = PeerId::random(); + let ps_id = PeerId::random(); + let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(8); + + let worker_allocator = StaticAllocator::new(vec![vec![ + TestWorkerBuilder::new().with_peer_id(w1_id).build(), + TestWorkerBuilder::new().with_peer_id(w2_id).build(), + ]]); + + let mut worker_pool_stream = PoolWithWorkerProperties::::new(Pool::new( + worker_allocator, + PoolConfig { + name: "workers".into(), + spec: hypha_messages::WorkerSpec { + resources: Resources::default(), + executor: vec![], + }, + price: PriceRange::default(), + min: 0, + target: 2, + grace: Duration::from_secs(1), + }, + )); + let worker_handle = worker_pool_stream.handle(); + + for _ in 0..2 { + tokio::time::timeout(Duration::from_secs(1), worker_pool_stream.next()) + .await + .expect("worker pool populate timeout") + .expect("worker pool ended") + .expect("worker allocation failed"); + if worker_pool_stream.properties().len() >= 2 { + break; + } + } + + let parameter_allocator = StaticAllocator::new(vec![vec![ + TestWorkerBuilder::new().with_peer_id(ps_id).build(), + ]]); + + let mut parameter_pool_stream = Pool::new( + parameter_allocator, + PoolConfig { + name: "ps".into(), + spec: hypha_messages::WorkerSpec { + resources: Resources::default(), + executor: vec![], + }, + price: PriceRange::default(), + min: 0, + target: 1, + grace: Duration::from_secs(1), + }, + ); + let parameter_handle = parameter_pool_stream.handle(); + tokio::time::timeout(Duration::from_secs(1), parameter_pool_stream.next()) + .await + .expect("parameter pool populate timeout") + .expect("parameter pool ended") + .expect("ps allocation failed"); + + let round_state = Arc::new(tokio::sync::Mutex::new(RoundState { + aggregated_updates: false, + first_update_at: None, + round_started_at: 
Instant::now(), + aggregate_started_at: None, + min_quorum: 2, + grace: Duration::from_millis(0), + round: 0, + update_rounds: 2, + training_complete: false, + push_done: false, + })); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(1))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let push_destination = Arc::new(Some(ModelDestiantion { + repository: "hf/repo".to_string(), + token: "token".to_string(), + })); + let start = std::time::Instant::now(); + let token = CancellationToken::new(); + + struct Step { + peer: PeerId, + status: ExecutorStatus, + check: Box, + } + + async fn dispatch( + peer: PeerId, + status: ExecutorStatus, + worker_handle: PoolWithWorkerPropertiesHandle, + parameter_handle: PoolHandle, + round_state: Arc>, + training_state: Arc>, + batch_sizer: Arc u32 + Send + Sync>, + push_destination: Arc>, + start: Instant, + tx: Sender<(PeerId, Metrics)>, + token: CancellationToken, + ) -> ExecutorAction { + schedule::( + tx, + worker_handle, + parameter_handle, + round_state, + training_state, + batch_sizer, + 1, + push_destination, + start, + ( + peer, + ActionRequest { + job_id: Uuid::new_v4(), + status, + }, + ), + token, + ) + .await + .unwrap() + .next + } + + let steps: Vec = vec![ + // Round 0 + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) + | ExecutorAction::Train(TrainAction::SendUpdate { .. }) => {} + other => panic!("expected ExecuteBatch or SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. 
+ }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 0, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 0, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::Idle), + check: Box::new(move |resp| match resp { + ExecutorAction::Aggregate(AggregateAction::AggregateUpdates { + source: Reference::Peers { peers, .. }, + }) => assert_eq!(peers, vec![w1_id, w2_id]), + other => panic!("expected AggregateUpdates, got {:?}", other), + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::AggregatedUpdates { + metrics: None, + }), + check: Box::new(move |resp| match resp { + ExecutorAction::Aggregate(AggregateAction::BroadcastUpdate { + target: Reference::Peers { peers, .. }, + }) => assert_eq!(peers, vec![w1_id, w2_id]), + other => panic!("expected BroadcastUpdate, got {:?}", other), + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::BroadcastedUpdate { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::Idle { .. 
}) + )) + }), + }, + // Round 1 + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) + | ExecutorAction::Train(TrainAction::SendUpdate { .. }) => {} + other => panic!("expected ExecuteBatch or SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 1, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 1, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::Idle), + check: Box::new(move |resp| match resp { + ExecutorAction::Aggregate(AggregateAction::AggregateUpdates { + source: Reference::Peers { peers, .. 
}, + }) => assert_eq!(peers, vec![w1_id, w2_id]), + other => panic!("expected AggregateUpdates, got {:?}", other), + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::AggregatedUpdates { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::BroadcastUpdate { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::BroadcastedUpdate { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::Idle { .. }) + )) + }), + }, + // Completion + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ApplyUpdate { .. }) + | ExecutorAction::Train(TrainAction::Idle { .. }) => {} + other => panic!("expected to apply update when aggregated, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::AppliedUpdate), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::PushToHub { repository, .. }) => { + assert_eq!(repository, "hf/repo"); + } + other => panic!("expected PushToHub, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::PushedToHub), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Terminate) + )) + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ApplyUpdate { .. 
}) + | ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) + | ExecutorAction::Train(TrainAction::Terminate) => {} + other => panic!( + "rejoined worker should be ready to apply or execute, got {:?}", + other + ), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::AppliedUpdate), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Terminate) + )) + }), + }, + ]; + + for step in steps { + let resp = dispatch( + step.peer, + step.status, + worker_handle.clone(), + parameter_handle.clone(), + round_state.clone(), + training_state.clone(), + batch_sizer.clone(), + push_destination.clone(), + start, + tx.clone(), + token.clone(), + ) + .await; + (step.check)(resp); + } + + let snapshot = worker_handle.properties(); + assert_eq!( + snapshot.iter().filter(|w| w.state.is_pusher).count(), + 1, + "only one worker should push the final model" + ); + assert!( + snapshot.iter().all(|w| w.state.applied_final_update), + "both workers should have applied the final update" + ); + let state = round_state.lock().await; + assert_eq!(state.round, 2, "should complete both rounds"); + assert!(state.training_complete, "training should be complete"); + assert!(state.push_done, "push should be marked done"); + } + + #[tokio::test] + async fn rejoining_worker_receives_model_before_participating() { + let w1_id = PeerId::random(); + let w2_id = PeerId::random(); + let ps_id = PeerId::random(); + let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(8); + + let worker_allocator = StaticAllocator::new(vec![vec![ + TestWorkerBuilder::new().with_peer_id(w1_id).build(), + TestWorkerBuilder::new().with_peer_id(w2_id).build(), + ]]); + + let mut worker_pool_stream = PoolWithWorkerProperties::::new(Pool::new( + 
worker_allocator, + PoolConfig { + name: "workers".into(), + spec: hypha_messages::WorkerSpec { + resources: Resources::default(), + executor: vec![], + }, + price: PriceRange::default(), + min: 0, + target: 2, + grace: Duration::from_secs(1), + }, + )); + let worker_handle = worker_pool_stream.handle(); + + for _ in 0..2 { + tokio::time::timeout(Duration::from_secs(1), worker_pool_stream.next()) + .await + .expect("worker pool populate timeout") + .expect("worker pool ended") + .expect("worker allocation failed"); + if worker_pool_stream.properties().len() >= 2 { + break; + } + } + + let parameter_allocator = StaticAllocator::new(vec![vec![ + TestWorkerBuilder::new().with_peer_id(ps_id).build(), + ]]); + + let mut parameter_pool_stream = Pool::new( + parameter_allocator, + PoolConfig { + name: "ps".into(), + spec: hypha_messages::WorkerSpec { + resources: Resources::default(), + executor: vec![], + }, + price: PriceRange::default(), + min: 0, + target: 1, + grace: Duration::from_secs(1), + }, + ); + let parameter_handle = parameter_pool_stream.handle(); + tokio::time::timeout(Duration::from_secs(1), parameter_pool_stream.next()) + .await + .expect("parameter pool populate timeout") + .expect("parameter pool ended") + .expect("ps allocation failed"); + + let round_state = Arc::new(tokio::sync::Mutex::new(RoundState { + aggregated_updates: false, + first_update_at: None, + round_started_at: Instant::now(), + aggregate_started_at: None, + min_quorum: 1, + grace: Duration::from_millis(0), + round: 0, + update_rounds: 3, + training_complete: false, + push_done: false, + })); + let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(1))); + let batch_sizer = Arc::new(|_: &Resources| 1u32); + let push_destination = Arc::new(None); + let start = std::time::Instant::now(); + let token = CancellationToken::new(); + + struct Step { + peer: PeerId, + status: ExecutorStatus, + check: Box, + } + + async fn dispatch( + peer: PeerId, + status: ExecutorStatus, + 
worker_handle: PoolWithWorkerPropertiesHandle, + parameter_handle: PoolHandle, + round_state: Arc>, + training_state: Arc>, + batch_sizer: Arc u32 + Send + Sync>, + push_destination: Arc>, + start: Instant, + tx: Sender<(PeerId, Metrics)>, + token: CancellationToken, + ) -> ExecutorAction { + schedule::( + tx, + worker_handle, + parameter_handle, + round_state, + training_state, + batch_sizer, + 1, + push_destination, + start, + ( + peer, + ActionRequest { + job_id: Uuid::new_v4(), + status, + }, + ), + token, + ) + .await + .unwrap() + .next + } + + let steps: Vec = vec![ + // Round 0 + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 0, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::Idle), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::AggregateUpdates { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::AggregatedUpdates { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::BroadcastUpdate { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::BroadcastedUpdate { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::Idle { .. 
}) + )) + }), + }, + // w2 joins late + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::Joined), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::WaitForModel { .. }) + )) + }), + }, + // Round 1 with w1 producing the update + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![ps_id]), + other => panic!("expected SendUpdate, got {:?}", other), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::SentUpdate { + round: 1, + metrics: HashMap::new(), + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::Idle), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::AggregateUpdates { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::AggregatedUpdates { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::BroadcastUpdate { .. }) + )) + }), + }, + Step { + peer: ps_id, + status: ExecutorStatus::Aggregate(AggregateStatus::BroadcastedUpdate { + metrics: None, + }), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Aggregate(AggregateAction::Idle { .. }) + )) + }), + }, + // w1 applies and transfers model to w2 + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ApplyUpdate { .. }) + | ExecutorAction::Train(TrainAction::Idle { .. 
}) + | ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) => {} + other => panic!( + "expected to apply/idle/execute before sending model, got {:?}", + other + ), + }), + }, + Step { + peer: w1_id, + status: ExecutorStatus::Train(TrainStatus::AppliedUpdate), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::SendModel { + target: Reference::Peers { peers, .. }, + }) => assert_eq!(peers, vec![w2_id]), + other => panic!( + "worker with update should send model to waiting peer, got {:?}", + other + ), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::WaitedForModel), + check: Box::new(move |resp| match resp { + ExecutorAction::Train(TrainAction::ReceiveModel { + source: Reference::Peers { peers, .. }, + .. + }) => assert_eq!(peers, vec![w1_id]), + other => panic!("waiting worker should receive model, got {:?}", other), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::ReceivedModel), + check: Box::new(|resp| { + assert!(matches!( + resp, + ExecutorAction::Train(TrainAction::Idle { .. }) + )) + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::Idle), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ApplyUpdate { .. 
}) + | ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) => {} + other => panic!("rejoined worker should apply or execute, got {:?}", other), + }), + }, + Step { + peer: w2_id, + status: ExecutorStatus::Train(TrainStatus::AppliedUpdate), + check: Box::new(|resp| match resp { + ExecutorAction::Train(TrainAction::ExecuteBatch { batches }) => { + assert_eq!(batches, 1); + } + other => panic!( + "rejoined worker should execute batches after applying, got {:?}", + other + ), + }), + }, + ]; + + for step in steps { + let resp = dispatch( + step.peer, + step.status, + worker_handle.clone(), + parameter_handle.clone(), + round_state.clone(), + training_state.clone(), + batch_sizer.clone(), + push_destination.clone(), + start, + tx.clone(), + token.clone(), + ) + .await; + (step.check)(resp); + } + + let snapshot = worker_handle.properties(); + let rejoined = snapshot + .iter() + .find(|w| w.peer_id == w2_id) + .expect("rejoined worker present"); + assert!( + !rejoined.state.waiting_for_model, + "waiting flag should be cleared after receiving model" + ); + assert!( + rejoined.state.applied_update, + "rejoined worker should have applied the current update" + ); + } + #[tokio::test] async fn handle_simple_two_rounds_simulation_drives_updates() { struct Step { @@ -1323,7 +2153,7 @@ mod batch_scheduler_tests { .build(), ]]); - let mut worker_pool_stream = PoolWithStatistics::::new(Pool::new( + let mut worker_pool_stream = PoolWithWorkerProperties::::new(Pool::new( worker_allocator, PoolConfig { name: "workers".into(), @@ -1365,7 +2195,7 @@ mod batch_scheduler_tests { .expect("worker pool populate timeout") .expect("worker pool ended") .expect("worker allocation failed"); - if worker_pool_stream.statistics().len() >= 3 { + if worker_pool_stream.properties().len() >= 3 { break; } } @@ -1378,7 +2208,6 @@ mod batch_scheduler_tests { let (tx, _rx) = tokio::sync::mpsc::channel::<(PeerId, Metrics)>(8); let round_state = 
std::sync::Arc::new(tokio::sync::Mutex::new(RoundState { - sent_updates: Default::default(), first_update_at: None, round_started_at: Instant::now(), aggregate_started_at: None, @@ -1386,12 +2215,9 @@ mod batch_scheduler_tests { grace: Duration::from_millis(0), round: 0, update_rounds: 10, - push_assigned: None, training_complete: false, - applied_final_update: Default::default(), push_done: false, aggregated_updates: false, - applied_updates: HashSet::default(), })); let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(800))); let batch_sizer = std::sync::Arc::new(|resources: &Resources| resources.gpu() as u32); @@ -1652,8 +2478,10 @@ mod batch_scheduler_tests { { let state = round_state.lock().await; assert_eq!(state.round, 1, "round advanced after broadcast"); + + let snapshot = worker_handle.properties(); assert!( - state.sent_updates.is_empty(), + snapshot.iter().all(|w| !w.state.sent_update), "sent updates cleared after broadcast" ); } From ce2bbc0a79f744fb982780eb781dc338be33ddda Mon Sep 17 00:00:00 2001 From: l45k Date: Mon, 29 Dec 2025 22:49:46 +0100 Subject: [PATCH 4/5] refactor: common transitions in scheduling --- .../src/scheduling/batch_scheduler.rs | 277 ++++++++---------- 1 file changed, 119 insertions(+), 158 deletions(-) diff --git a/crates/scheduler/src/scheduling/batch_scheduler.rs b/crates/scheduler/src/scheduling/batch_scheduler.rs index 8b88afe5..53c3199f 100644 --- a/crates/scheduler/src/scheduling/batch_scheduler.rs +++ b/crates/scheduler/src/scheduling/batch_scheduler.rs @@ -27,7 +27,7 @@ use uuid::Uuid; use crate::{ metrics_bridge::Metrics, network::Network, - pool::{PoolHandle, PoolWithWorkerPropertiesHandle}, + pool::{PoolHandle, PoolWithWorkerPropertiesHandle, WorkerDescriptorWithProperties}, scheduler_config::ModelDestiantion, simulation::Simulation, statistics::RuntimeStatistic, @@ -112,6 +112,103 @@ pub enum BatchSchedulerError { SendMetricsError(#[from] SendError<(PeerId, Metrics)>), } +/// Handle 
common transitions in Idle and BatchCompleted +#[allow(clippy::too_many_arguments)] +async fn handle_common_transition( + training_state: Arc>, + batch_sizer: BatchSizer, + multi_batch_size: u32, + snapshot: Vec, + peer_id: PeerId, + parameter_servers: Vec, + short_idle: SystemTime, +) -> Result +where + S: Simulation + Send + Sync + 'static, +{ + let peer_position = snapshot + .iter() + .position(|w| w.peer_id == peer_id) + .unwrap_or(0); + + // Get peer contribution from pool + let peer_contribution = snapshot + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.samples_processed) + .unwrap_or(0); + + let (count, update_target) = { + let training = training_state.lock().await; + (training.get_count(), training.get_update_target()) + }; + + let stats: Vec = snapshot.iter().map(|w| w.statistic.unwrap_or(0)).collect(); + let progress: Vec = snapshot + .iter() + .map(|w| w.last_updated.unwrap_or(0)) + .collect(); + let batch_sizes: Vec = snapshot + .iter() + .map(|w| (batch_sizer)(&w.resources)) + .collect(); + + let (should_update, projected_target, batches) = if update_target <= count { + (true, count, 0) + } else if !snapshot.is_empty() + && batch_sizes.iter().all(|&b| b > 0) + && stats.iter().all(|&s| s > 0 && s < u64::MAX) + { + let (time, cnt, projection, capped) = S::project( + &progress, + &batch_sizes, + stats, + update_target.saturating_sub(count), + multi_batch_size, + ); + + tracing::debug!( + time = %time, + count = %cnt, + peer = %peer_id, + capped, + "Simulation with projection {:?} and {:?}", + projection, + update_target.saturating_sub(count) + ); + ( + cnt <= 0 + && !capped + && peer_position < projection.len() + && projection[peer_position] == 0, + count.saturating_add(cnt.unsigned_abs()), + projection[peer_position], + ) + } else { + (false, count, multi_batch_size) + }; + + if !should_update { + Ok(ExecutorAction::Train(TrainAction::ExecuteBatch { batches })) + } else if parameter_servers.is_empty() { + // NOTE: If we need to send an 
update but there are no parameter servers, + // we must wait (idle) until one becomes available. + Ok(ExecutorAction::Train(TrainAction::Idle { + timeout: short_idle, + })) + } else { + Ok(ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { + // Selecting a single PS to avoid that workers send updates to multiple PS + peers: vec![parameter_servers[0]], + strategy: SelectionStrategy::One, + resource: None, + }, + weight: peer_contribution as f32 / projected_target as f32, + })) + } +} + /// Handle action protocol requests and respond with next steps. #[allow(clippy::too_many_arguments)] async fn schedule( @@ -197,66 +294,6 @@ where let mut state = round_state.lock().await; if !state.training_complete { let snapshot = worker_pool.properties(); - let peer_position = snapshot - .iter() - .position(|w| w.peer_id == peer_id) - .unwrap_or(0); - - // Get peer contribution from pool - let peer_contribution = snapshot - .iter() - .find(|w| w.peer_id == peer_id) - .map(|w| w.state.samples_processed) - .unwrap_or(0); - - let (count, update_target) = { - let training = training_state.lock().await; - (training.get_count(), training.get_update_target()) - }; - - let stats: Vec = - snapshot.iter().map(|w| w.statistic.unwrap_or(0)).collect(); - let progress: Vec = snapshot - .iter() - .map(|w| w.last_updated.unwrap_or(0)) - .collect(); - let batch_sizes: Vec = snapshot - .iter() - .map(|w| (batch_sizer)(&w.resources)) - .collect(); - - let (should_update, projected_target) = if update_target <= count { - (true, count) - } else if !snapshot.is_empty() - && batch_sizes.iter().all(|&b| b > 0) - && stats.iter().all(|&s| s > 0 && s < u64::MAX) - { - let (time, cnt, projection, capped) = S::project( - &progress, - &batch_sizes, - stats, - update_target.saturating_sub(count), - multi_batch_size, - ); - - tracing::debug!( - time = %time, - count = %cnt, - peer = %peer_id, - "Simulation with projection {:?} and {:?}", - projection, - 
update_target.saturating_sub(count) - ); - ( - cnt <= 0 - && !capped - && peer_position < projection.len() - && projection[peer_position] == 0, - count.saturating_add(cnt.unsigned_abs()), - ) - } else { - (false, count) - }; // Check if peer has applied update or sent update let (has_applied_update, has_sent_update, _applied_final_update) = snapshot @@ -284,26 +321,17 @@ where ExecutorAction::Train(TrainAction::Idle { timeout: short_idle, }) - } else if !should_update { - ExecutorAction::Train(TrainAction::ExecuteBatch { - batches: multi_batch_size, - }) - } else if parameter_servers.is_empty() { - // NOTE: If we need to send an update but there are no parameter servers, - // we must wait (idle) until one becomes available. - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) } else { - ExecutorAction::Train(TrainAction::SendUpdate { - target: Reference::Peers { - // Selecting a single PS to avoid that workers send updates to multiple PS - peers: vec![parameter_servers[0]], - strategy: SelectionStrategy::One, - resource: None, - }, - weight: peer_contribution as f32 / projected_target as f32, - }) + handle_common_transition::( + training_state, + batch_sizer, + multi_batch_size, + snapshot, + peer_id, + parameter_servers, + short_idle, + ) + .await? 
} } else if state.push_done { cancel.cancel(); @@ -362,15 +390,10 @@ where }); let snapshot = worker_pool.properties(); - let peer_position = snapshot - .iter() - .position(|w| w.peer_id == peer_id) - .unwrap_or(0); - let (count, update_target) = { + { let mut training = training_state.lock().await; training.record_batch(batch_size); - (training.get_count(), training.get_update_target()) }; // Update per-worker samples @@ -378,13 +401,6 @@ where s.samples_processed = s.samples_processed.saturating_add(batch_size); }); - // Get fresh snapshot for peer contribution after update - let peer_contribution = snapshot - .iter() - .find(|w| w.peer_id == peer_id) - .map(|w| w.state.samples_processed.saturating_add(batch_size)) - .unwrap_or(batch_size); - let (training_complete, sent_update) = { let state = round_state.lock().await; let sent = snapshot @@ -400,71 +416,16 @@ where timeout: short_idle, }) } else { - let stats: Vec = - snapshot.iter().map(|w| w.statistic.unwrap_or(0)).collect(); - let progress: Vec = snapshot - .iter() - .map(|w| w.last_updated.unwrap_or(0)) - .collect(); - let batch_sizes: Vec = snapshot - .iter() - .map(|w| (batch_sizer)(&w.resources)) - .collect(); - - let (should_update, projected_target, batches) = if update_target <= count { - (true, count, 0) - } else if !snapshot.is_empty() - && batch_sizes.iter().all(|&b| b > 0) - && stats.iter().all(|&s| s > 0 && s < u64::MAX) - { - let (time, cnt, projection, capped) = S::project( - &progress, - &batch_sizes, - stats, - update_target.saturating_sub(count), - multi_batch_size, - ); - - tracing::debug!( - time = %time, - count = %cnt, - peer = %peer_id, - capped, - "Simulation with projection {:?} and {:?}", - projection, - update_target.saturating_sub(count) - ); - ( - cnt <= 0 - && !capped - && peer_position < projection.len() - && projection[peer_position] == 0, - count.saturating_add(cnt.unsigned_abs()), - projection[peer_position], - ) - } else { - (false, count, multi_batch_size) - }; - - if 
!should_update { - ExecutorAction::Train(TrainAction::ExecuteBatch { batches }) - } else if parameter_servers.is_empty() { - // NOTE: If we need to send an update but there are no parameter servers, - // we must wait (idle) until one becomes available. - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) - } else { - ExecutorAction::Train(TrainAction::SendUpdate { - target: Reference::Peers { - // Selecting a single PS to avoid that workers send updates to multiple PS - peers: vec![parameter_servers[0]], - strategy: SelectionStrategy::One, - resource: None, - }, - weight: peer_contribution as f32 / projected_target as f32, - }) - } + handle_common_transition::( + training_state, + batch_sizer, + multi_batch_size, + snapshot, + peer_id, + parameter_servers, + short_idle, + ) + .await? } } TrainStatus::SentUpdate { mut metrics, .. } => { From 34b38eedd4122fc506a3137f31ba12eb99bc8a0c Mon Sep 17 00:00:00 2001 From: l45k Date: Wed, 31 Dec 2025 17:40:10 +0100 Subject: [PATCH 5/5] wip --- crates/messages/src/lib.rs | 13 + crates/scheduler/src/pool.rs | 5 +- crates/scheduler/src/scheduler_config.rs | 4 +- .../src/scheduling/batch_scheduler.rs | 587 +++++++++++------- crates/scheduler/src/statistics.rs | 51 +- .../worker/src/executor/parameter_server.rs | 2 +- .../src/hypha/accelerate_executor/training.py | 21 +- 7 files changed, 431 insertions(+), 252 deletions(-) diff --git a/crates/messages/src/lib.rs b/crates/messages/src/lib.rs index 8e7b9ac7..195c22ac 100644 --- a/crates/messages/src/lib.rs +++ b/crates/messages/src/lib.rs @@ -98,12 +98,16 @@ pub mod action { Idle, BatchCompleted { batch_size: u32, + batches: u32, }, + WaitedForParameterServer, SentUpdate { round: u32, metrics: HashMap, }, + WaitedForUpdate, AppliedUpdate, + WaitedForNextRound, PushedToHub, SentModel, ReceivedModel, @@ -167,14 +171,23 @@ pub mod action { ExecuteBatch { batches: u32, }, + WaitForParameterServer { + timeout: SystemTime, + }, SendUpdate { target: Reference, weight: 
f32, }, + WaitForUpdate { + timeout: SystemTime, + }, ApplyUpdate { source: Reference, timeout: SystemTime, }, + WaitForNextRound { + timeout: SystemTime, + }, /// DEPRECATED: Temporary path to push final weights to Hugging Face. /// Prefer dedicated artifact publishing in future revisions. PushToHub { diff --git a/crates/scheduler/src/pool.rs b/crates/scheduler/src/pool.rs index e7c7e1d3..820467ee 100644 --- a/crates/scheduler/src/pool.rs +++ b/crates/scheduler/src/pool.rs @@ -424,7 +424,6 @@ impl WorkerDescriptorWithProperties { pub struct WorkerState { pub sent_update: bool, pub applied_update: bool, - pub applied_final_update: bool, pub samples_processed: u32, pub waiting_for_model: bool, pub receiving_from: Option, @@ -811,13 +810,13 @@ mod tests { pool_with_stats.update_statistics(&peer_id, |stats, last_updated| { if *last_updated > 0 { - stats.update(10u64.saturating_sub(*last_updated)); + stats.update(10u64.saturating_sub(*last_updated), 1); } *last_updated = 10; }); pool_with_stats.update_statistics(&peer_id, |stats, last_updated| { if *last_updated > 0 { - stats.update(25u64.saturating_sub(*last_updated)); + stats.update(25u64.saturating_sub(*last_updated), 1); } *last_updated = 25; }); diff --git a/crates/scheduler/src/scheduler_config.rs b/crates/scheduler/src/scheduler_config.rs index f1c58869..dd5ff7c0 100644 --- a/crates/scheduler/src/scheduler_config.rs +++ b/crates/scheduler/src/scheduler_config.rs @@ -56,7 +56,7 @@ pub struct DiLoCo { #[serde(rename = "outer_optimizer")] pub outer_optimizer: Nesterov, pub resources: DiLoCoResources, - pub model_destination: Option, + pub model_destination: Option, } #[derive(Deserialize, Serialize, Debug, Clone, Copy)] @@ -179,7 +179,7 @@ impl From for Model { } #[derive(Deserialize, Serialize, Debug, Clone)] -pub struct ModelDestiantion { +pub struct ModelDestination { pub repository: String, pub token: String, } diff --git a/crates/scheduler/src/scheduling/batch_scheduler.rs 
b/crates/scheduler/src/scheduling/batch_scheduler.rs index 53c3199f..e2bebd39 100644 --- a/crates/scheduler/src/scheduling/batch_scheduler.rs +++ b/crates/scheduler/src/scheduling/batch_scheduler.rs @@ -28,7 +28,7 @@ use crate::{ metrics_bridge::Metrics, network::Network, pool::{PoolHandle, PoolWithWorkerPropertiesHandle, WorkerDescriptorWithProperties}, - scheduler_config::ModelDestiantion, + scheduler_config::ModelDestination, simulation::Simulation, statistics::RuntimeStatistic, }; @@ -46,6 +46,7 @@ struct RoundState { update_rounds: u32, training_complete: bool, push_done: bool, + projected_taret: u32, } impl Default for RoundState { @@ -61,6 +62,7 @@ impl Default for RoundState { update_rounds: 0, training_complete: false, push_done: false, + projected_taret: 0, } } } @@ -112,10 +114,36 @@ pub enum BatchSchedulerError { SendMetricsError(#[from] SendError<(PeerId, Metrics)>), } +async fn handle_send_update( + parameter_servers: Vec, + round_state: Arc>, + peer_contribution: u32, + short_idle: SystemTime, +) -> Result { + if parameter_servers.is_empty() { + // NOTE: If we need to send an update but there are no parameter servers, + // we must wait until one becomes available. 
+ Ok(ExecutorAction::Train(TrainAction::WaitForParameterServer { + timeout: short_idle, + })) + } else { + Ok(ExecutorAction::Train(TrainAction::SendUpdate { + target: Reference::Peers { + // Selecting a single PS to avoid that workers send updates to multiple PS + peers: vec![parameter_servers[0]], + strategy: SelectionStrategy::One, + resource: None, + }, + weight: peer_contribution as f32 / round_state.lock().await.projected_taret as f32, + })) + } +} + /// Handle common transitions in Idle and BatchCompleted #[allow(clippy::too_many_arguments)] -async fn handle_common_transition( +async fn execute_or_update( training_state: Arc>, + round_state: Arc>, batch_sizer: BatchSizer, multi_batch_size: u32, snapshot: Vec, @@ -153,12 +181,9 @@ where .map(|w| (batch_sizer)(&w.resources)) .collect(); - let (should_update, projected_target, batches) = if update_target <= count { - (true, count, 0) - } else if !snapshot.is_empty() - && batch_sizes.iter().all(|&b| b > 0) - && stats.iter().all(|&s| s > 0 && s < u64::MAX) - { + let (should_update, batches) = if update_target <= count { + (true, 0) + } else if batch_sizes.iter().all(|&b| b > 0) && stats.iter().all(|&s| s > 0 && s < u64::MAX) { let (time, cnt, projection, capped) = S::project( &progress, &batch_sizes, @@ -167,7 +192,7 @@ where multi_batch_size, ); - tracing::debug!( + tracing::info!( time = %time, count = %cnt, peer = %peer_id, @@ -176,36 +201,78 @@ where projection, update_target.saturating_sub(count) ); + + { + let projected_target = count.saturating_add(cnt.unsigned_abs()); + let mut state = round_state.lock().await; + if state.projected_taret != projected_target { + state.projected_taret = projected_target; + } + }; + ( - cnt <= 0 - && !capped - && peer_position < projection.len() - && projection[peer_position] == 0, - count.saturating_add(cnt.unsigned_abs()), + cnt <= 0 && !capped && projection[peer_position] == 0, projection[peer_position], ) } else { - (false, count, multi_batch_size) + { + let mut state = 
round_state.lock().await; + if state.projected_taret != update_target { + state.projected_taret = update_target; + } + }; + (false, multi_batch_size) }; if !should_update { Ok(ExecutorAction::Train(TrainAction::ExecuteBatch { batches })) - } else if parameter_servers.is_empty() { - // NOTE: If we need to send an update but there are no parameter servers, - // we must wait (idle) until one becomes available. + } else { + handle_send_update( + parameter_servers, + round_state, + peer_contribution, + short_idle, + ) + .await + } +} + +// Handle training finalization transitions +#[allow(clippy::too_many_arguments)] +async fn finalize_training( + round_state: Arc>, + peer_id: PeerId, + cancel: CancellationToken, + short_idle: SystemTime, + worker_pool: PoolWithWorkerPropertiesHandle, + push_destination: Arc>, +) -> Result { + let mut state = round_state.lock().await; + let push_assigned = worker_pool.properties().iter().any(|w| w.state.is_pusher); + + // Model already pushed -> Terminate + if state.push_done { + cancel.cancel(); + Ok(ExecutorAction::Train(TrainAction::Terminate)) + } else if !push_assigned { + // Assign self as pusher + worker_pool.update_state(&peer_id, |s| s.is_pusher = true); + + if let Some(destination) = push_destination.as_ref().as_ref() { + Ok(ExecutorAction::Train(TrainAction::PushToHub { + repository: destination.repository.clone(), + token: destination.token.clone(), + })) + } else { + // Model should not be pushed + state.push_done = true; + Ok(ExecutorAction::Train(TrainAction::Terminate)) + } + } else { + // Wait until model is pushed. 
Ok(ExecutorAction::Train(TrainAction::Idle { timeout: short_idle, })) - } else { - Ok(ExecutorAction::Train(TrainAction::SendUpdate { - target: Reference::Peers { - // Selecting a single PS to avoid that workers send updates to multiple PS - peers: vec![parameter_servers[0]], - strategy: SelectionStrategy::One, - resource: None, - }, - weight: peer_contribution as f32 / projected_target as f32, - })) } } @@ -219,7 +286,7 @@ async fn schedule( training_state: Arc>, batch_sizer: BatchSizer, multi_batch_size: u32, - push_destination: Arc>, + push_destination: Arc>, start: std::time::Instant, request: (PeerId, action::ActionRequest), cancel: CancellationToken, @@ -233,11 +300,12 @@ where tracing::debug!(%peer_id, ?status, %job_id, "Received action request"); let now = SystemTime::now(); - // Rimeouts sized for ~100ms RTT with generous margins. + // Timeouts sized for ~100ms RTT with generous margins. let short_idle = now + Duration::from_millis(500); let wait_model = now + Duration::from_secs(1); - let long_io = now + Duration::from_secs(60); let ps_broadcast_idle = now + Duration::from_secs(5); + let short_io = now + Duration::from_secs(10); + let long_io = now + Duration::from_secs(60); let since_start = start.elapsed().as_millis() as u64; // NOTE: We rely on Pool::members() being oldest-first ordered by join time. @@ -247,11 +315,18 @@ where let next_action = match status { ExecutorStatus::Train(train) => match train { TrainStatus::Joined => { - let state = round_state.lock().await; - if state.round == 0 { - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) + if round_state.lock().await.round == 0 { + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? 
} else { worker_pool.update_state(&peer_id, |s| s.waiting_for_model = true); ExecutorAction::Train(TrainAction::WaitForModel { @@ -283,150 +358,109 @@ where } } TrainStatus::ReceivedModel => { - // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { timeout: now }) + // Participate in training + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? } TrainStatus::SentModel => { - // Lazy transition to other state - ExecutorAction::Train(TrainAction::Idle { timeout: now }) + // Participate in training + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? } TrainStatus::Idle => { - let mut state = round_state.lock().await; - if !state.training_complete { - let snapshot = worker_pool.properties(); - - // Check if peer has applied update or sent update - let (has_applied_update, has_sent_update, _applied_final_update) = snapshot - .iter() - .find(|w| w.peer_id == peer_id) - .map(|w| { - ( - w.state.applied_update, - w.state.sent_update, - w.state.applied_final_update, - ) - }) - .unwrap_or((false, false, false)); - - if state.aggregated_updates && !has_applied_update { - ExecutorAction::Train(TrainAction::ApplyUpdate { - source: Reference::Peers { - peers: parameter_servers, - strategy: SelectionStrategy::All, - resource: None, - }, - timeout: now + Duration::from_secs(10), - }) - } else if has_sent_update { - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) - } else { - handle_common_transition::( - training_state, - batch_sizer, - multi_batch_size, - snapshot, - peer_id, - parameter_servers, - short_idle, - ) - .await? - } - } else if state.push_done { - cancel.cancel(); - ExecutorAction::Train(TrainAction::Terminate) + // Training in progress? 
+ if !round_state.lock().await.training_complete { + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? } else { - let snapshot = worker_pool.properties(); - let has_push_assignment = snapshot.iter().any(|w| w.state.is_pusher); - let (applied_final_update, _am_pusher, has_applied_update) = snapshot - .iter() - .find(|w| w.peer_id == peer_id) - .map(|w| { - ( - w.state.applied_final_update, - w.state.is_pusher, - w.state.applied_update, - ) - }) - .unwrap_or((false, false, false)); - - if state.aggregated_updates && !has_applied_update { - ExecutorAction::Train(TrainAction::ApplyUpdate { - source: Reference::Peers { - peers: parameter_servers, - strategy: SelectionStrategy::All, - resource: None, - }, - timeout: now + Duration::from_secs(10), - }) - } else if !has_push_assignment && applied_final_update { - // Assign self as pusher - worker_pool.update_state(&peer_id, |s| s.is_pusher = true); - - if let Some(destination) = push_destination.as_ref().as_ref() { - ExecutorAction::Train(TrainAction::PushToHub { - repository: destination.repository.clone(), - token: destination.token.clone(), - }) - } else { - // Should not occur due to guard above. - state.push_done = true; - ExecutorAction::Train(TrainAction::Terminate) - } - } else { - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) - } + finalize_training( + round_state, + peer_id, + cancel.clone(), + short_idle, + worker_pool, + push_destination, + ) + .await? 
} } - TrainStatus::BatchCompleted { batch_size } => { + TrainStatus::BatchCompleted { + batch_size, + batches, + } => { + // START: Update statistics and state worker_pool.update_statistics(&peer_id, |stats, last_updated| { if *last_updated > 0 { - stats.update(since_start.saturating_sub(*last_updated)); + stats.update(since_start.saturating_sub(*last_updated), batches as u64); } *last_updated = since_start; }); - let snapshot = worker_pool.properties(); - { let mut training = training_state.lock().await; - training.record_batch(batch_size); + training.record_batch(batch_size * batches); }; - // Update per-worker samples worker_pool.update_state(&peer_id, |s| { - s.samples_processed = s.samples_processed.saturating_add(batch_size); + s.samples_processed = s.samples_processed.saturating_add(batch_size * batches); }); + // END: Update statistics and state + + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? + } + TrainStatus::WaitedForParameterServer => { + // Get peer contribution from pool + let peer_contribution = worker_pool + .properties() + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.samples_processed) + .unwrap_or(0); - let (training_complete, sent_update) = { - let state = round_state.lock().await; - let sent = snapshot - .iter() - .find(|w| w.peer_id == peer_id) - .map(|w| w.state.sent_update) - .unwrap_or(false); - (state.training_complete, sent) - }; - - if training_complete || sent_update { - ExecutorAction::Train(TrainAction::Idle { - timeout: short_idle, - }) - } else { - handle_common_transition::( - training_state, - batch_sizer, - multi_batch_size, - snapshot, - peer_id, - parameter_servers, - short_idle, - ) - .await? - } + handle_send_update( + parameter_servers, + round_state, + peer_contribution, + short_idle, + ) + .await? } TrainStatus::SentUpdate { mut metrics, .. 
} => { let snapshot = worker_pool.properties(); @@ -482,56 +516,90 @@ where "Worker reported SentUpdate; recorded for round" ); - ExecutorAction::Train(TrainAction::Idle { + ExecutorAction::Train(TrainAction::WaitForUpdate { timeout: short_idle, }) } + TrainStatus::WaitedForUpdate => { + if round_state.lock().await.aggregated_updates { + ExecutorAction::Train(TrainAction::ApplyUpdate { + source: Reference::Peers { + peers: parameter_servers, + strategy: SelectionStrategy::All, + resource: None, + }, + timeout: short_io, + }) + } else { + ExecutorAction::Train(TrainAction::WaitForUpdate { + timeout: short_idle, + }) + } + } TrainStatus::AppliedUpdate => { - let training_complete = { - worker_pool.update_state(&peer_id, |s| s.applied_update = true); + worker_pool.update_state(&peer_id, |s| s.applied_update = true); - let state = round_state.lock().await; - if state.training_complete { - worker_pool.update_state(&peer_id, |s| s.applied_final_update = true); - true - } else { - false - } - }; + // Find a worker waiting for model + let snapshot = worker_pool.properties(); + let waiting_worker = snapshot + .iter() + .find(|w| w.state.waiting_for_model) + .map(|w| w.peer_id); - if training_complete { - ExecutorAction::Train(TrainAction::Idle { - timeout: now + Duration::from_millis(500), + if round_state.lock().await.training_complete { + finalize_training( + round_state, + peer_id, + cancel, + short_idle, + worker_pool, + push_destination, + ) + .await? 
+ } else if let Some(update_worker) = waiting_worker { + // Mark them as not waiting and set receiving from + worker_pool.update_state(&update_worker, |s| { + s.waiting_for_model = false; + s.receiving_from = Some(peer_id); + }); + + ExecutorAction::Train(TrainAction::SendModel { + target: Reference::Peers { + peers: vec![update_worker], + strategy: SelectionStrategy::One, + resource: None, + }, }) } else { - // Find a worker waiting for model - let snapshot = worker_pool.properties(); - let waiting_worker = snapshot - .iter() - .find(|w| w.state.waiting_for_model) - .map(|w| w.peer_id); - - if let Some(update_worker) = waiting_worker { - // Mark them as not waiting and set receiving from - worker_pool.update_state(&update_worker, |s| { - s.waiting_for_model = false; - s.receiving_from = Some(peer_id); - }); - - ExecutorAction::Train(TrainAction::SendModel { - target: Reference::Peers { - peers: vec![update_worker], - strategy: SelectionStrategy::One, - resource: None, - }, - }) - } else { - // We can either move through idle or expect that the parameters are tuned - // s.t., its okay to execute a multi batch in the first round. - ExecutorAction::Train(TrainAction::ExecuteBatch { - batches: multi_batch_size, - }) - } + ExecutorAction::Train(TrainAction::WaitForNextRound { + timeout: short_idle, + }) + } + } + TrainStatus::WaitedForNextRound => { + // We are in the next round if sent_update is reset + if worker_pool + .properties() + .iter() + .find(|w| w.peer_id == peer_id) + .map(|w| w.state.sent_update) + .unwrap_or(true) + { + ExecutorAction::Train(TrainAction::WaitForNextRound { + timeout: short_idle, + }) + } else { + execute_or_update::( + training_state, + round_state, + batch_sizer, + multi_batch_size, + worker_pool.properties(), + peer_id, + parameter_servers, + short_idle, + ) + .await? 
} } TrainStatus::PushedToHub => { @@ -580,7 +648,7 @@ where let all_applied = worker_pool .properties() .iter() - .all(|w| w.state.applied_final_update); + .all(|w| w.state.applied_update); if all_applied { ExecutorAction::Aggregate(AggregateAction::Terminate) @@ -820,7 +888,7 @@ impl BatchScheduler { grace: Duration, samples_between_updates: u32, update_rounds: u32, - push_destination: Option, + push_destination: Option, batch_sizer: BatchSizer, multi_batch_size: u32, cancel: CancellationToken, @@ -850,6 +918,7 @@ impl BatchScheduler { training_complete: false, push_done: false, aggregated_updates: false, + projected_taret: 0, })); let training_state = Arc::new(Mutex::new(TrainingState::new(samples_between_updates))); network @@ -939,7 +1008,7 @@ mod batch_scheduler_tests { pool::{ Pool, PoolConfig, PoolHandle, PoolWithWorkerProperties, PoolWithWorkerPropertiesHandle, }, - scheduler_config::{ModelDestiantion, PriceRange}, + scheduler_config::{ModelDestination, PriceRange}, simulation::BasicSimulation, statistics::{RunningMean, RuntimeStatistic}, worker::{TestWorkerBuilder, Worker}, @@ -966,7 +1035,7 @@ mod batch_scheduler_tests { } impl RuntimeStatistic for TestStat { - fn update(&mut self, time: u64) { + fn update(&mut self, time: u64, _count: u64) { let delta = time.saturating_sub(self.last_updated); self.last_updated = time; self.value = delta.max(1); @@ -1135,7 +1204,10 @@ mod batch_scheduler_tests { PeerId::random(), ActionRequest { job_id: Uuid::new_v4(), - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 4 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 4, + batches: 1, + }), }, ), token.clone(), @@ -1203,7 +1275,10 @@ mod batch_scheduler_tests { PeerId::random(), ActionRequest { job_id: Uuid::new_v4(), - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 10 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 10, + batches: 1, + }), }, ), token.clone(), 
@@ -1362,10 +1437,11 @@ mod batch_scheduler_tests { update_rounds: 2, training_complete: false, push_done: false, + projected_taret: 0, })); let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(1))); let batch_sizer = Arc::new(|_: &Resources| 1u32); - let push_destination = Arc::new(Some(ModelDestiantion { + let push_destination = Arc::new(Some(ModelDestination { repository: "hf/repo".to_string(), token: "token".to_string(), })); @@ -1386,7 +1462,7 @@ mod batch_scheduler_tests { round_state: Arc>, training_state: Arc>, batch_sizer: Arc u32 + Send + Sync>, - push_destination: Arc>, + push_destination: Arc>, start: Instant, tx: Sender<(PeerId, Metrics)>, token: CancellationToken, @@ -1419,7 +1495,10 @@ mod batch_scheduler_tests { // Round 0 Step { peer: w2_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(|resp| match resp { ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) | ExecutorAction::Train(TrainAction::SendUpdate { .. }) => {} @@ -1428,7 +1507,10 @@ mod batch_scheduler_tests { }, Step { peer: w2_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. }, @@ -1452,7 +1534,10 @@ mod batch_scheduler_tests { }, Step { peer: w1_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. 
}, @@ -1511,7 +1596,10 @@ mod batch_scheduler_tests { // Round 1 Step { peer: w1_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(|resp| match resp { ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1 }) | ExecutorAction::Train(TrainAction::SendUpdate { .. }) => {} @@ -1520,7 +1608,10 @@ mod batch_scheduler_tests { }, Step { peer: w1_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. }, @@ -1544,7 +1635,10 @@ mod batch_scheduler_tests { }, Step { peer: w2_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. 
}, @@ -1700,7 +1794,7 @@ mod batch_scheduler_tests { "only one worker should push the final model" ); assert!( - snapshot.iter().all(|w| w.state.applied_final_update), + snapshot.iter().all(|w| w.state.applied_update), "both workers should have applied the final update" ); let state = round_state.lock().await; @@ -1784,6 +1878,7 @@ mod batch_scheduler_tests { update_rounds: 3, training_complete: false, push_done: false, + projected_taret: 0, })); let training_state = Arc::new(tokio::sync::Mutex::new(TrainingState::new(1))); let batch_sizer = Arc::new(|_: &Resources| 1u32); @@ -1805,7 +1900,7 @@ mod batch_scheduler_tests { round_state: Arc>, training_state: Arc>, batch_sizer: Arc u32 + Send + Sync>, - push_destination: Arc>, + push_destination: Arc>, start: Instant, tx: Sender<(PeerId, Metrics)>, token: CancellationToken, @@ -1838,7 +1933,10 @@ mod batch_scheduler_tests { // Round 0 Step { peer: w1_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. }, @@ -1908,7 +2006,10 @@ mod batch_scheduler_tests { // Round 1 with w1 producing the update Step { peer: w1_id, - status: ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 1 }), + status: ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 1, + batches: 1, + }), check: Box::new(move |resp| match resp { ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers, .. 
}, @@ -2179,6 +2280,7 @@ mod batch_scheduler_tests { training_complete: false, push_done: false, aggregated_updates: false, + projected_taret: 0, })); let training_state = std::sync::Arc::new(tokio::sync::Mutex::new(TrainingState::new(800))); let batch_sizer = std::sync::Arc::new(|resources: &Resources| resources.gpu() as u32); @@ -2204,7 +2306,10 @@ mod batch_scheduler_tests { let steps = vec![ Step::new( w3_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 50 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 50, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2212,7 +2317,10 @@ mod batch_scheduler_tests { ), Step::new( w2_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 100 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 100, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2220,7 +2328,10 @@ mod batch_scheduler_tests { ), Step::new( w1_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 150 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 150, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2228,7 +2339,10 @@ mod batch_scheduler_tests { ), Step::new( w3_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 50 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 50, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2236,7 +2350,10 @@ mod batch_scheduler_tests { ), Step::new( w3_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 50 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 50, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2244,7 
+2361,10 @@ mod batch_scheduler_tests { ), Step::new( w2_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 100 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 100, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2252,7 +2372,10 @@ mod batch_scheduler_tests { ), Step::new( w1_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 150 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 150, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::ExecuteBatch { batches: 1, }), @@ -2260,7 +2383,10 @@ mod batch_scheduler_tests { ), Step::new( w3_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 50 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 50, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers: vec![ps_id], @@ -2273,7 +2399,10 @@ mod batch_scheduler_tests { ), Step::new( w2_id, - ExecutorStatus::Train(TrainStatus::BatchCompleted { batch_size: 100 }), + ExecutorStatus::Train(TrainStatus::BatchCompleted { + batch_size: 100, + batches: 1, + }), hypha_messages::action::ExecutorAction::Train(TrainAction::SendUpdate { target: Reference::Peers { peers: vec![ps_id], diff --git a/crates/scheduler/src/statistics.rs b/crates/scheduler/src/statistics.rs index 7951ff32..5ef04a27 100644 --- a/crates/scheduler/src/statistics.rs +++ b/crates/scheduler/src/statistics.rs @@ -1,5 +1,5 @@ pub trait RuntimeStatistic: Send + Sync + Default { - fn update(&mut self, time: u64); + fn update(&mut self, time: u64, count: u64); fn value(&self) -> u64; } @@ -26,15 +26,17 @@ impl Default for RunningMean { } impl RuntimeStatistic for RunningMean { - fn update(&mut self, time: u64) { - if self.samples == 0 { - self.running_mean = time; - self.samples = 1; - } else { - self.samples += 1; - self.running_mean = 
(self.running_mean as i64 - + (time as i64 - self.running_mean as i64) / self.samples as i64) - as u64; + fn update(&mut self, time: u64, count: u64) { + if count > 0 { + if self.samples == 0 { + self.running_mean = time / count; + self.samples = count; + } else { + self.samples += count; + // Weight the batch by `count`: mean += (time - mean * count) / samples. + let delta = time as i64 - self.running_mean as i64 * count as i64; + self.running_mean = (self.running_mean as i64 + delta / self.samples as i64) as u64; + } } } @@ -51,20 +53,41 @@ mod tests { fn running_mean_update() { let mut running_mean = RunningMean::new(); - running_mean.update(1050); + running_mean.update(1050, 1); assert_eq!(running_mean.samples, 1, "First update"); assert_eq!(running_mean.value(), 1050); - running_mean.update(1000); + running_mean.update(1000, 1); assert_eq!(running_mean.samples, 2, "Second update"); assert_eq!(running_mean.value(), 1025); - running_mean.update(1025); + running_mean.update(1025, 1); assert_eq!(running_mean.samples, 3, "Third update"); assert_eq!(running_mean.value(), 1025); - running_mean.update(2050); + running_mean.update(2050, 1); assert_eq!(running_mean.samples, 4, "Fourth update"); assert_eq!(running_mean.value(), 1281); } + + #[test] + fn running_mean_weighted_update() { + let mut running_mean = RunningMean::new(); + + running_mean.update(10, 1); + assert_eq!(running_mean.samples, 1, "First update"); + assert_eq!(running_mean.value(), 10); + + running_mean.update(20, 2); + assert_eq!(running_mean.samples, 3, "Second update"); + assert_eq!(running_mean.value(), 10); + + running_mean.update(90, 3); + assert_eq!(running_mean.samples, 6, "Third update"); + assert_eq!(running_mean.value(), 20); + + running_mean.update(160, 4); + assert_eq!(running_mean.samples, 10, "Fourth update"); + assert_eq!(running_mean.value(), 28); + } } diff --git a/crates/worker/src/executor/parameter_server.rs b/crates/worker/src/executor/parameter_server.rs index 1904a053..30060105 100644 --- a/crates/worker/src/executor/parameter_server.rs +++
b/crates/worker/src/executor/parameter_server.rs @@ -167,7 +167,7 @@ impl JobExecutor for ParameterServerExecutor { Ok(n) => { let _ = f.sync_all().await; let _ = fs::set_permissions(&file_path, Permissions::from_mode(0o600)).await; - tracing::debug!(peer_id = %peer, size = n, file = %file_path.display(), "Received update"); } + tracing::info!(peer_id = %peer, size = n, file = %file_path.display(), "Received update"); } Err(err) => { tracing::error!(error = %err, file = %file_path.display(), "Failed to write received update"); } diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py index 9c295db2..0cc88ecc 100644 --- a/executors/accelerate/src/hypha/accelerate_executor/training.py +++ b/executors/accelerate/src/hypha/accelerate_executor/training.py @@ -125,8 +125,9 @@ def sleep_until_epoch_ms(target_ms: int) -> None: model = get_model(local_fetch_path, config["model"]["task"]) optimizer = get_adam(config["optimizer"], model.parameters()) scheduler = get_scheduler(config.get("scheduler"), optimizer) + batch_size = config["batch_size"] data_loader = torch.utils.data.DataLoader( - IterableStreamDataSet(args.socket, work_dir, local_fetch_path, config["batch_size"], config), + IterableStreamDataSet(args.socket, work_dir, local_fetch_path, batch_size, config), batch_size=None, pin_memory=True, num_workers=4, @@ -187,13 +188,17 @@ def sleep_until_epoch_ms(target_ms: int) -> None: optimizer.step() scheduler.step() if accelerator.is_main_process: - batch_size += next(iter(batch.values())).shape[0] loss_list.append(loss.detach().cpu().numpy()) if accelerator.is_main_process: current_status = { "executor": "train", - "details": {"state": "batch-completed", "batch_size": batch_size}, + "details": {"state": "batch-completed", "batch_size": batch_size, "batches": local_batches}, } + elif kind == "wait-for-parameter-server": + timeout_ms = system_time_to_epoch_ms(action.get("timeout")) + if 
timeout_ms is not None: + sleep_until_epoch_ms(timeout_ms) + current_status = {"executor": "train", "details": {"state": "waited-for-parameter-server"}} elif kind == "send-update": target = action.get("target") if target is None: @@ -250,6 +255,11 @@ def sleep_until_epoch_ms(target_ms: int) -> None: "message": str(exc), }, } + elif kind == "wait-for-update": + timeout_ms = system_time_to_epoch_ms(action.get("timeout")) + if timeout_ms is not None: + sleep_until_epoch_ms(timeout_ms) + current_status = {"executor": "train", "details": {"state": "waited-for-update"}} elif kind == "apply-update": source = action.get("source") if source is None: @@ -304,6 +314,11 @@ def sleep_until_epoch_ms(target_ms: int) -> None: "details": {"state": "applied-update"}, } epoch_counter += 1 + elif kind == "wait-for-next-round": + timeout_ms = system_time_to_epoch_ms(action.get("timeout")) + if timeout_ms is not None: + sleep_until_epoch_ms(timeout_ms) + current_status = {"executor": "train", "details": {"state": "waited-for-next-round"}} elif kind == "push-to-hub": repository = action.get("repository") token = action.get("token")