From 9b8855dd62dacb32a14b0e9a31dcbc8e59c0f266 Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Fri, 19 Dec 2025 12:13:49 +0800 Subject: [PATCH] Configure expected_replicas to avoid running tasks with unexpected replicas. --- src/lib.rs | 1 + src/lighthouse.rs | 466 +++++++++++++++++++++++++++++++++++++++++++++- src/manager.rs | 4 + 3 files changed, 467 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7291c09f..da3a8091 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -643,6 +643,7 @@ impl LighthouseServer { join_timeout_ms, quorum_tick_ms, heartbeat_timeout_ms, + expected_replicas: None, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; diff --git a/src/lighthouse.rs b/src/lighthouse.rs index ea1aa3f9..c510e00e 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -128,6 +128,14 @@ pub struct LighthouseOpt { help = "How long to wait for a heartbeat before considering a replica dead." )] pub heartbeat_timeout_ms: u64, + + #[structopt( + long = "expected_replicas", + use_delimiter = true, + value_delimiter = ",", + help = "Expected number of replicas (comma-separated, strictly increasing, e.g., '1,2,4')" + )] + pub expected_replicas: Option<Vec<u64>>, } fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool { @@ -137,6 +145,80 @@ fn quorum_changed(a: &Vec<QuorumMember>, b: &Vec<QuorumMember>) -> bool { return a_ids != b_ids; } +// Find the expected replica count based on candidate participant count. +// Returns the largest value in expected_replicas that is <= current_count. +fn find_expected_replica_count(expected_replicas: &[u64], candidate_count: usize) -> Option<usize> { + expected_replicas + .iter() + .filter(|&&n| n <= candidate_count as u64) + .max() + .map(|&n| n as usize) +} + +// Apply expected_replicas truncation while preserving prev_quorum participants. +// When truncating, this function: +// 1. Preserves all prev_quorum participants (to avoid disrupting running training) +// 2. Fills remaining slots with new participants (sorted by smallest replica IDs) +// 3. 
Returns the truncated participant list +fn apply_expected_replicas_truncation( + mut participants: Vec<QuorumMember>, + expected_replicas: &Option<Vec<u64>>, + prev_quorum: &Option<Quorum>, +) -> Vec<QuorumMember> { + if let Some(ref expected) = expected_replicas { + if let Some(expected_count) = find_expected_replica_count(expected, participants.len()) { + if expected_count < participants.len() { + // If prev_quorum exists, preserve its participants + if let Some(ref prev) = prev_quorum { + let prev_replica_ids: HashSet<&String> = + prev.participants.iter().map(|p| &p.replica_id).collect(); + + // Separate into prev_quorum participants (already training) and new participants + let mut prev_participants: Vec<QuorumMember> = participants + .iter() + .filter(|p| prev_replica_ids.contains(&p.replica_id)) + .cloned() + .collect(); + + let mut new_participants: Vec<QuorumMember> = participants + .iter() + .filter(|p| !prev_replica_ids.contains(&p.replica_id)) + .cloned() + .collect(); + + // Calculate how many new participants we can add + // If prev_participants exceeds expected_count, we need to shrink prev_participants too + if prev_participants.len() > expected_count { + // Shrink case: truncate prev_participants to expected_count + // Keep participants with smallest replica IDs (consistent with expansion logic) + prev_participants.sort_by_key(|p| p.replica_id.clone()); + prev_participants.truncate(expected_count); + participants = prev_participants; + } else { + // Expansion or stable case: keep all prev_participants and add new ones if there's room + let remaining_slots = expected_count - prev_participants.len(); + + // Truncate new participants (they are already sorted by ID) + // Keep only the ones with smallest IDs + new_participants.truncate(remaining_slots); + + // Combine: all prev participants + selected new participants + participants = prev_participants; + participants.extend(new_participants); + + // Sort again to maintain consistent ordering + participants.sort_by_key(|p| p.replica_id.clone()); + } + } else { + // No 
prev_quorum, just truncate normally (keep smallest IDs) + participants.truncate(expected_count); + } + } + } + } + participants +} + // Checks whether the quorum is valid, the new quorum and an explanation for the state. fn quorum_compute( now: Instant, @@ -167,6 +249,8 @@ fn quorum_compute( .collect(); // Sort by replica ID to get a consistent ordering across runs. + // This ensures that when truncating based on expected_replicas, + // we keep participants with smaller IDs and remove those with larger IDs. candidate_participants.sort_by_key(|p| p.replica_id.clone()); let shrink_only = healthy_participants @@ -207,9 +291,27 @@ fn quorum_compute( .all(|prev_member| healthy_participants.contains_key(&prev_member.replica_id)); if is_fast_quorum { + // Apply expected_replicas constraint if set + // Preserve prev_quorum participants (already training) when truncating + let result_participants = apply_expected_replicas_truncation( + candidate_participants.clone(), + &opt.expected_replicas, + &state.prev_quorum, + ); + + let final_metadata = format!( + "[{}/{} participants selected][{}/{} participants healthy][{} heartbeating][shrink_only={}]", + result_participants.len(), + candidate_participants.len(), + healthy_participants.len(), + state.participants.len(), + healthy_replicas.len(), + shrink_only, + ); + return ( - Some(candidate_participants), - format!("Fast quorum found! {}", metadata), + Some(result_participants), + format!("Fast quorum found! 
{}", final_metadata), ); } } } @@ -262,14 +364,51 @@ fn quorum_compute( ); } + // Apply expected_replicas constraint if set + // Preserve prev_quorum participants (already training) when truncating + let result_participants = apply_expected_replicas_truncation( + candidate_participants.clone(), + &opt.expected_replicas, + &state.prev_quorum, + ); + + let final_metadata = format!( + "[{}/{} participants selected][{}/{} participants healthy][{} heartbeating][shrink_only={}]", + result_participants.len(), + candidate_participants.len(), + healthy_participants.len(), + state.participants.len(), + healthy_replicas.len(), + shrink_only, + ); + ( - Some(candidate_participants), - format!("Valid quorum found {}", metadata), + Some(result_participants), + format!("Valid quorum found {}", final_metadata), ) } impl Lighthouse { pub async fn new(opt: LighthouseOpt) -> Result<Arc<Self>> { + // Validate expected_replicas constraints + if let Some(ref expected) = opt.expected_replicas { + // Check if strictly increasing + for i in 1..expected.len() { + if expected[i] <= expected[i - 1] { + return Err(anyhow!("expected_replicas must be strictly increasing")); + } + } + + // Check if first element equals min_replicas + if !expected.is_empty() && expected[0] != opt.min_replicas { + return Err(anyhow!( + "expected_replicas[0] ({}) must equal min_replicas ({})", + expected[0], + opt.min_replicas + )); + } + } + let listener = tokio::net::TcpListener::bind(&opt.bind).await?; let (tx, _) = broadcast::channel(16); @@ -632,6 +771,7 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let mut state = State { @@ -711,6 +851,7 @@ mod tests { join_timeout_ms: 0, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let mut state = State { @@ -797,6 +938,7 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let mut state = 
State { @@ -887,6 +1029,7 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let mut state = State { @@ -982,6 +1125,7 @@ mod tests { join_timeout_ms: 1, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let lighthouse = Lighthouse::new(opt).await?; @@ -1028,6 +1172,7 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; let mut state = State { @@ -1138,6 +1283,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; // Start the lighthouse service @@ -1251,6 +1397,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + expected_replicas: None, }; // Start the lighthouse service @@ -1295,4 +1442,315 @@ mod tests { lighthouse_task.abort(); Ok(()) } + + #[tokio::test] + async fn test_expected_replicas_truncate() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + expected_replicas: Some(vec![1, 2, 4]), + }; + + let mut state = State { + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + heartbeats: HashMap::new(), + }; + + let now = Instant::now(); + + // Add 3 participants: "a", "b", "c" + for id in ["a", "b", "c"] { + state.participants.insert( + id.to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: id.to_string(), + address: format!("addr_{}", id), + store_address: format!("store_{}", id), + step: 1, + world_size: 3, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + }, + ); + state.heartbeats.insert(id.to_string(), now); + } + + // With 3 participants and expected_replicas=[1,2,4], should get 2 participants + let (quorum_met, reason) = 
quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + let participants = quorum_met.unwrap(); + assert_eq!( + participants.len(), + 2, + "Should have 2 participants (from expected_replicas)" + ); + + // Verify that we keep the participants with smaller IDs ("a" and "b") + assert_eq!( + participants[0].replica_id, "a", + "First participant should be 'a' (smallest ID)" + ); + assert_eq!( + participants[1].replica_id, "b", + "Second participant should be 'b' (second smallest ID)" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_expected_replicas_preserves_prev_quorum() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + expected_replicas: Some(vec![1, 2, 4]), + }; + + let mut state = State { + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 1, + heartbeats: HashMap::new(), + }; + + let now = Instant::now(); + + // Setup prev_quorum with participants "b" and "d" (not sorted, to test that we preserve them) + state.prev_quorum = Some(Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "b".to_string(), + address: "addr_b".to_string(), + store_address: "store_b".to_string(), + step: 1, + world_size: 2, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + QuorumMember { + replica_id: "d".to_string(), + address: "addr_d".to_string(), + store_address: "store_d".to_string(), + step: 1, + world_size: 2, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + ], + created: Some(SystemTime::now().into()), + }); + + // Add 4 participants: "a", "b", "c", "d" (all healthy) + // "b" and "d" are from prev_quorum, "a" and "c" are new + for id in ["a", "b", "c", "d"] { + state.participants.insert( + id.to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: id.to_string(), + address: 
format!("addr_{}", id), + store_address: format!("store_{}", id), + step: 1, + world_size: 4, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + }, + ); + state.heartbeats.insert(id.to_string(), now); + } + + // With 4 participants and expected_replicas=[1,2,4], should get 4 participants + // But this is fast_quorum (all prev_quorum members are healthy) + // If we had 5 participants with expected_replicas=[1,2,4], we should keep 4 + // Let's add a 5th participant "e" + state.participants.insert( + "e".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "e".to_string(), + address: "addr_e".to_string(), + store_address: "store_e".to_string(), + step: 1, + world_size: 5, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + }, + ); + state.heartbeats.insert("e".to_string(), now); + + // With 5 participants and expected_replicas=[1,2,4], should get 4 participants + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + let participants = quorum_met.unwrap(); + assert_eq!( + participants.len(), + 4, + "Should have 4 participants (from expected_replicas)" + ); + + // Verify the sorted order: should be ["a", "b", "c", "d"], but "e" should be excluded + assert_eq!( + participants[0].replica_id, "a", + "First participant should be 'a' (sorted order)" + ); + assert_eq!( + participants[1].replica_id, "b", + "Second participant should be 'b' (sorted order)" + ); + assert_eq!( + participants[2].replica_id, "c", + "Third participant should be 'c' (sorted order)" + ); + assert_eq!( + participants[3].replica_id, "d", + "Fourth participant should be 'd' (sorted order)" + ); + Ok(()) + } + + #[tokio::test] + async fn test_expected_replicas_shrink() -> Result<()> { + // This test simulates the user's scenario: + // - Start with 4 replicas (world_size=4) + // - One replica fails, leaving 3 healthy + // - With expected_replicas=[1,2,4], should 
shrink to 2 (not stay at 3) + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + expected_replicas: Some(vec![1, 2, 4]), + }; + + let mut state = State { + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 1, + heartbeats: HashMap::new(), + }; + + let now = Instant::now(); + + // Setup prev_quorum with 4 participants (simulating initial quorum) + state.prev_quorum = Some(Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica0".to_string(), + address: "addr0".to_string(), + store_address: "store0".to_string(), + step: 100, + world_size: 4, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + QuorumMember { + replica_id: "replica1".to_string(), + address: "addr1".to_string(), + store_address: "store1".to_string(), + step: 100, + world_size: 4, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + QuorumMember { + replica_id: "replica2".to_string(), + address: "addr2".to_string(), + store_address: "store2".to_string(), + step: 100, + world_size: 4, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + QuorumMember { + replica_id: "replica3".to_string(), + address: "addr3".to_string(), + store_address: "store3".to_string(), + step: 100, + world_size: 4, + shrink_only: false, + data: String::new(), + commit_failures: 0, + }, + ], + created: Some(SystemTime::now().into()), + }); + + // Now only 3 replicas are healthy (replica3 failed) + for id in ["replica0", "replica1", "replica2"] { + state.participants.insert( + id.to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: id.to_string(), + address: format!("addr_{}", &id[7..]), // extract number from "replicaN" + store_address: format!("store_{}", &id[7..]), + step: 100, + world_size: 3, + shrink_only: false, + data: String::new(), + 
commit_failures: 0, + }, + }, + ); + state.heartbeats.insert(id.to_string(), now); + } + + // With 3 healthy participants and expected_replicas=[1,2,4], + // should truncate to 2 participants (not keep all 3) + let (quorum_met, reason) = quorum_compute(now, &state, &opt); + assert!(quorum_met.is_some(), "{}", reason); + + let participants = quorum_met.unwrap(); + assert_eq!( + participants.len(), + 2, + "Should have 2 participants (from expected_replicas), not 3. Reason: {}", + reason + ); + + // Should keep replica0 and replica1 (smallest IDs) + assert_eq!(participants[0].replica_id, "replica0"); + assert_eq!(participants[1].replica_id, "replica1"); + + // Verify the reason includes the truncation info + assert!( + reason.contains("2/3 participants selected"), + "Reason should show truncation: {}", + reason + ); + + Ok(()) + } } diff --git a/src/manager.rs b/src/manager.rs index 816e06ab..4f219472 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -661,6 +661,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + expected_replicas: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -709,6 +710,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + expected_replicas: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -765,6 +767,7 @@ mod tests { min_replicas: 2, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + expected_replicas: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -838,6 +841,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + expected_replicas: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run());