diff --git a/easyfix-session/src/session.rs b/easyfix-session/src/session.rs index b6ca545..cc082f6 100644 --- a/easyfix-session/src/session.rs +++ b/easyfix-session/src/session.rs @@ -625,30 +625,62 @@ impl Session { } #[instrument(level = "trace", skip_all)] - fn resend_range(&self, state: &mut State, begin_seq_num: SeqNum, mut end_seq_num: SeqNum) { + async fn resend_range(&self, begin_seq_num: SeqNum, mut end_seq_num: SeqNum) { + /// Maximum number of messages to process before yielding to other tasks. + /// This prevents resend storms from starving other connections. + const RESEND_BATCH_SIZE: usize = 64; + info!("resend range: ({begin_seq_num}, {end_seq_num})"); - let next_sender_msg_seq_num = state.next_sender_msg_seq_num(); - if end_seq_num == 0 || end_seq_num >= next_sender_msg_seq_num { - end_seq_num = next_sender_msg_seq_num - 1; - info!("adjust end_seq_num to {end_seq_num}"); - } - // Just do a gap fill when messages aren't persisted - if !self.session_settings.persist { + // Collect messages and required state within a short-lived borrow scope + let (messages, persist) = { + let mut state = self.state.borrow_mut(); let next_sender_msg_seq_num = state.next_sender_msg_seq_num(); - end_seq_num += 1; - if end_seq_num > next_sender_msg_seq_num { - end_seq_num = next_sender_msg_seq_num; + if end_seq_num == 0 || end_seq_num >= next_sender_msg_seq_num { + end_seq_num = next_sender_msg_seq_num - 1; + info!("adjust end_seq_num to {end_seq_num}"); + } + + // Just do a gap fill when messages aren't persisted + if !self.session_settings.persist { + let next_sender_msg_seq_num = state.next_sender_msg_seq_num(); + let mut gap_end = end_seq_num + 1; + if gap_end > next_sender_msg_seq_num { + gap_end = next_sender_msg_seq_num; + } + drop(state); + self.send_sequence_reset(begin_seq_num, gap_end); + return; } - self.send_sequence_reset(begin_seq_num, end_seq_num); + + // Collect messages to release the borrow on state before yielding. + // This is necessary because we cannot hold RefCell borrow across await points. + info!("fetch messages range from {begin_seq_num} to {end_seq_num}"); + let messages: Vec> = state + .fetch_range(begin_seq_num..=end_seq_num) + .map(|msg| msg.to_vec()) + .collect(); + + (messages, self.session_settings.persist) + }; + // State borrow is now dropped + + if !persist { return; } let mut gap_fill_range = None; - info!("fetch messages range from {begin_seq_num} to {end_seq_num}"); - for msg_str in state.fetch_range(begin_seq_num..=end_seq_num) { - // TODO: log error! and resend as gap fill instead of unwrap - let mut msg = match FixtMessage::from_bytes(msg_str) { + let mut processed_count = 0usize; + + for msg_bytes in messages { + // Yield periodically to prevent resend storms from starving other connections + if processed_count > 0 && processed_count % RESEND_BATCH_SIZE == 0 { + trace!("yielding after processing {processed_count} resend messages"); + tokio::task::yield_now().await; + } + processed_count += 1; + + let mut msg = match FixtMessage::from_bytes(&msg_bytes) { Ok(msg) => msg, Err(err) => { error!(%err, "Failed to decode message bytes"); @@ -759,9 +791,9 @@ impl Session { info!("Received ResendRequest FROM: {begin_seq_no} TO: {end_seq_no}"); - let mut state = self.state.borrow_mut(); + self.resend_range(begin_seq_no, end_seq_no).await; - self.resend_range(&mut state, begin_seq_no, end_seq_no); + let mut state = self.state.borrow_mut(); if Self::is_target_too_high(&state, msg_seq_num) { // XXX: This message will be ignored during queued messages @@ -1049,28 +1081,18 @@ impl Session { { // is the 789 lower (we checked for higher previously) than our next message after receiving the logon if next_expected_msg_seq_num != next_sender_msg_num_at_logon_received { - let mut end_seq_no = next_sender_msg_num_at_logon_received; - - // TODO: self.resend_range() will handle this !!! - if !self.session_settings.persist { - end_seq_no += 1; - let next = state.next_sender_msg_seq_num(); - if end_seq_no > next { - end_seq_no = next; - } - info!( - "Received implicit ResendRequest via Logon FROM: {next_expected_msg_seq_num}, \ - TO: {next_sender_msg_num_at_logon_received} will be reset" - ); - self.send_sequence_reset(next_expected_msg_seq_num, end_seq_no); - } else { - // resend missed messages - info!( - "Received implicit ResendRequest via Logon FROM: {next_expected_msg_seq_num} \ - TO: {next_sender_msg_num_at_logon_received} will be resent" - ); - self.resend_range(&mut state, next_expected_msg_seq_num, end_seq_no) - } + let end_seq_no = next_sender_msg_num_at_logon_received; + + // resend missed messages (handles both persist and non-persist cases) + info!( + "Received implicit ResendRequest via Logon FROM: {next_expected_msg_seq_num} \ + TO: {next_sender_msg_num_at_logon_received}" + ); + // Drop state borrow before async call + drop(state); + self.resend_range(next_expected_msg_seq_num, end_seq_no).await; + // Re-borrow state for subsequent code + state = self.state.borrow_mut(); } } @@ -1410,3 +1432,239 @@ impl Session { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::application::events_channel; + use crate::messages_storage::InMemoryStorage; + use chrono::NaiveTime; + use easyfix_messages::fields::FixString; + use std::cell::RefCell; + use std::rc::Rc; + use tokio::sync::mpsc; + + fn test_settings() -> Settings { + Settings { + sender_comp_id: FixString::from_ascii_lossy(b"SENDER".to_vec()), + sender_sub_id: None, + heartbeat_interval: Some(30), + auto_disconnect_after_no_logon_received: Duration::from_secs(10), + auto_disconnect_after_no_heartbeat: 3, + auto_disconnect_after_no_logout: Duration::from_secs(5), + } + } + + fn test_session_settings() -> SessionSettings { + SessionSettings { + session_id: SessionId::new( + FixString::from_ascii_lossy(b"FIX.4.4".to_vec()), + FixString::from_ascii_lossy(b"SENDER".to_vec()), + FixString::from_ascii_lossy(b"TARGET".to_vec()), + ), + session_time: NaiveTime::from_hms_opt(0, 0, 0).unwrap() + ..=NaiveTime::from_hms_opt(23, 59, 59).unwrap(), + logon_time: NaiveTime::from_hms_opt(0, 0, 0).unwrap() + ..=NaiveTime::from_hms_opt(23, 59, 59).unwrap(), + send_redundant_resend_requests: false, + check_comp_id: true, + max_latency: None, + reset_on_logon: false, + reset_on_logout: false, + reset_on_disconnect: false, + refresh_on_logon: false, + sender_default_appl_ver_id: FixString::from_ascii_lossy(b"9".to_vec()), + target_default_appl_ver_id: FixString::from_ascii_lossy(b"9".to_vec()), + enable_next_expected_msg_seq_num: false, + persist: true, + verify_logout: true, + verify_test_request_id: true, + } + } + + /// Create a simple FIX Reject message for testing (Reject is an admin message that does NOT get gap-filled) + fn create_test_message(seq_num: SeqNum) -> Box { + Box::new(FixtMessage { + header: Box::new(Header { + msg_seq_num: seq_num, + msg_type: MsgType::Reject, + sender_comp_id: FixString::from_ascii_lossy(b"SENDER".to_vec()), + target_comp_id: FixString::from_ascii_lossy(b"TARGET".to_vec()), + sending_time: UtcTimestamp::now(), + begin_string: FixString::from_ascii_lossy(b"FIXT.1.1".to_vec()), + ..Default::default() + }), + body: Box::new(Message::Reject(Reject { + ref_seq_num: seq_num, + ..Default::default() + })), + trailer: Box::new(new_trailer()), + }) + } + + #[tokio::test] + async fn test_resend_range_no_persist_sends_gap_fill() { + // Disable persistence to test gap fill path + let mut session_settings = test_session_settings(); + session_settings.persist = false; + + let storage = InMemoryStorage::new(); + let state = Rc::new(RefCell::new(State::new(storage))); + + // Set next sender msg seq num so there's a range to gap fill + state.borrow_mut().set_next_sender_msg_seq_num(11); + + let (sender_tx, mut sender_rx) = mpsc::unbounded_channel(); + let sender = Sender::new(sender_tx); + + let (emitter, _event_stream) = events_channel(); + + let (disconnect_tx, _disconnect_rx) = tokio::sync::oneshot::channel(); + + let session = Session::new( + test_settings(), + session_settings, + state, + sender, + emitter, + disconnect_tx, + ); + + // Request resend of messages 1-10 + session.resend_range(1, 10).await; + + // Should receive a SequenceReset (gap fill) + let msg = sender_rx.recv().await.expect("should receive message"); + match msg { + crate::SenderMsg::Msg(fixt_msg) => { + assert_eq!(fixt_msg.msg_type(), MsgType::SequenceReset); + if let Message::SequenceReset(seq_reset) = &*fixt_msg.body { + assert_eq!(seq_reset.gap_fill_flag, Some(true)); + } else { + panic!("Expected SequenceReset message"); + } + } + _ => panic!("Expected Msg, got Disconnect"), + } + } + + #[tokio::test] + async fn test_resend_range_with_persist_resends_messages() { + let storage = InMemoryStorage::new(); + let state = Rc::new(RefCell::new(State::new(storage))); + + // Store some messages (note: InMemoryStorage.fetch_range ignores the range + // and returns all stored messages, so we store exactly what we want to resend) + let num_messages = 5; + for seq_num in 1..=num_messages { + let msg = create_test_message(seq_num); + let serialized = msg.serialize(); + state.borrow_mut().store(seq_num, &serialized); + } + // Set next sender seq num + state + .borrow_mut() + .set_next_sender_msg_seq_num((num_messages + 1) as SeqNum); + + let (sender_tx, mut sender_rx) = mpsc::unbounded_channel(); + let sender = Sender::new(sender_tx); + + let (emitter, _event_stream) = events_channel(); + + let (disconnect_tx, _disconnect_rx) = tokio::sync::oneshot::channel(); + + let mut session_settings = test_session_settings(); + session_settings.persist = true; + + let session = Session::new( + test_settings(), + session_settings, + state, + sender, + emitter, + disconnect_tx, + ); + + // Request resend of messages 1-5 + session.resend_range(1, num_messages as SeqNum).await; + + // Should receive resent messages (with PossDupFlag set) + let mut received_count = 0; + while let Ok(msg) = sender_rx.try_recv() { + match msg { + crate::SenderMsg::Msg(fixt_msg) => { + // Resent messages should have PossDupFlag=true + assert_eq!(fixt_msg.header.poss_dup_flag, Some(true)); + received_count += 1; + } + _ => panic!("Expected Msg, got Disconnect"), + } + } + + assert_eq!( + received_count, num_messages, + "Should have resent {} messages", + num_messages + ); + } + + #[tokio::test] + async fn test_resend_range_yields_after_batch() { + // This test verifies that the resend loop yields periodically + // We'll store more than RESEND_BATCH_SIZE (64) messages and verify + // that all messages are processed correctly + + let storage = InMemoryStorage::new(); + let state = Rc::new(RefCell::new(State::new(storage))); + + // Store 150 messages (more than 2 batches of 64) + // Note: InMemoryStorage.fetch_range ignores the range parameter + let num_messages = 150; + for seq_num in 1..=num_messages { + let msg = create_test_message(seq_num as SeqNum); + let serialized = msg.serialize(); + state.borrow_mut().store(seq_num as SeqNum, &serialized); + } + // Set next sender seq num + state + .borrow_mut() + .set_next_sender_msg_seq_num((num_messages + 1) as SeqNum); + + let (sender_tx, mut sender_rx) = mpsc::unbounded_channel(); + let sender = Sender::new(sender_tx); + + let (emitter, _event_stream) = events_channel(); + + let (disconnect_tx, _disconnect_rx) = tokio::sync::oneshot::channel(); + + let mut session_settings = test_session_settings(); + session_settings.persist = true; + + let session = Session::new( + test_settings(), + session_settings, + state, + sender, + emitter, + disconnect_tx, + ); + + // Request resend of all messages + session.resend_range(1, num_messages as SeqNum).await; + + // Count received messages + let mut received_count = 0; + while let Ok(msg) = sender_rx.try_recv() { + match msg { + crate::SenderMsg::Msg(_) => received_count += 1, + _ => {} + } + } + + assert_eq!( + received_count, num_messages, + "Should have resent all {} messages", + num_messages + ); + } +}