Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/flux-network/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ keywords = ["low-latency", "performance", "communication"]
repository = "https://github.com/gattaca-com/flux"

[dependencies]
flux.workspace = true
flux-communication.workspace = true
flux-timing.workspace = true
flux-utils.workspace = true
Expand All @@ -17,5 +18,7 @@ tracing.workspace = true
wincode = { workspace = true, optional = true }

[dev-dependencies]
spine-derive.workspace = true
tempfile.workspace = true
wincode-derive = { workspace = true }
wincode = { workspace = true }
161 changes: 147 additions & 14 deletions crates/flux-network/src/tcp/connector.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{net::SocketAddr, sync::Arc};
use std::net::SocketAddr;

use flux::spine::{SpineProducerWithDCache, SpineProducers};
use flux_timing::{Duration, Nanos, Repeater};
use flux_utils::{DCache, safe_panic};
use flux_utils::{DCachePtr, safe_panic};
use mio::{Events, Interest, Poll, Token, event::Event, net::TcpListener};
use tracing::{debug, error, warn};

use crate::tcp::{ConnState, MessagePayload, TcpStream, TcpTelemetry, stream::set_socket_buf_size};
use crate::tcp::{ConnState, TcpStream, TcpTelemetry, stream::set_socket_buf_size};

#[derive(Clone, Copy, Debug)]
#[repr(u8)]
Expand All @@ -28,8 +29,12 @@ pub enum ConnectionVariant {
Listener(TcpListener),
}

/// Event emitted by [`TcpConnector::poll_with`] for each notable IO occurrence.
pub enum PollEvent<'a> {
/// Event emitted by [`TcpConnector::poll_with`] and
/// [`TcpConnector::poll_with_produce`] for each notable IO occurrence.
///
/// For [`poll_with`]: `Payload = &'a [u8]`.
/// For [`poll_with_produce`]: `Payload = Result<T, E>`.
pub enum PollEvent<Payload> {
/// A new connection was accepted from a listener.
///
/// - `listener`: token of the listening socket that accepted
Expand All @@ -43,7 +48,7 @@ pub enum PollEvent<'a> {
/// A connection was closed (by the remote or due to an IO error).
Disconnect { token: Token },
/// A complete framed message was received.
Message { token: Token, payload: MessagePayload<'a>, send_ts: Nanos },
Message { token: Token, payload: Payload, send_ts: Nanos },
}

struct ConnectionManager {
Expand All @@ -53,7 +58,7 @@ struct ConnectionManager {
on_connect_msg: Option<Vec<u8>>,
telemetry: TcpTelemetry,
socket_buf_size: Option<usize>,
dcache: Option<Arc<DCache>>,
dcache: Option<DCachePtr>,

// Always only outbound/client side connection streams
to_be_reconnected: Vec<(Token, ConnectionVariant)>,
Expand Down Expand Up @@ -314,7 +319,7 @@ impl ConnectionManager {
#[inline]
fn handle_event<F>(&mut self, e: &Event, handler: &mut F)
where
F: for<'a> FnMut(PollEvent<'a>),
F: for<'a> FnMut(PollEvent<&'a [u8]>),
{
let event_token = e.token();
let Some(stream_id) = self.conns.iter().position(|(t, _)| t == &event_token) else {
Expand All @@ -330,8 +335,8 @@ impl ConnectionManager {
self.poll.registry(),
e,
self.dcache.as_deref(),
&mut |token, payload, send_ts| {
handler(PollEvent::Message { token, payload, send_ts });
&mut |token, bytes, send_ts| {
handler(PollEvent::Message { token, payload: bytes, send_ts });
},
) == ConnState::Disconnected
{
Expand Down Expand Up @@ -390,6 +395,97 @@ impl ConnectionManager {
}
}
}

/// Dispatches a single mio event on the dcache "produce" path.
///
/// The connection whose token equals `e.token()` is looked up in
/// `self.conns`; if none matches, the problem is reported through
/// `safe_panic!` and the function returns.
///
/// - Stream connection (`Outbound` / `Inbound`): the event is forwarded to
///   `TcpStream::poll_with_produce` together with `parse` and `produce`.
///   Every result it reports is passed to `handler` as
///   `PollEvent::Message`. If the stream reports `ConnState::Disconnected`,
///   `handler` receives `PollEvent::Disconnect` and the connection is handed
///   to `disconnect_at_index`.
/// - `Listener`: connections are accepted in a loop until `accept()` returns
///   an error. Each accepted stream is registered for `READABLE` under the
///   next token, gets `TCP_NODELAY`, is optionally sent the configured
///   on-connect message, is announced via `PollEvent::Accept` and is then
///   stored as an `Inbound` connection.
///
/// # Panics
/// Panics (via `expect`) when the event belongs to a stream connection and no
/// dcache has been configured on this manager.
#[inline]
fn handle_event_produce<T, E, G, P, F>(
&mut self,
e: &Event,
parse: &mut G,
produce: &mut P,
handler: &mut F,
) where
T: 'static + Copy,
G: FnMut(Token, &[u8]) -> Result<T, E>,
P: SpineProducers + AsRef<SpineProducerWithDCache<T>>,
F: FnMut(PollEvent<Result<T, E>>),
{
let event_token = e.token();
// Map the event token back to its slot in `self.conns` (linear scan).
let Some(stream_id) = self.conns.iter().position(|(t, _)| t == &event_token) else {
safe_panic!("got event for unknown token");
return;
};

// The loop only repeats in the listener arm (one iteration per accept
// attempt); the stream arm always returns after a single pass.
loop {
match &mut self.conns[stream_id].1 {
ConnectionVariant::Outbound(tcp_connection) |
ConnectionVariant::Inbound(tcp_connection) => {
// The produce path needs a dcache; its absence is treated as a
// configuration bug.
let dcache =
self.dcache.as_deref().expect("dcache required for poll_with_produce");
if tcp_connection.poll_with_produce(
self.poll.registry(),
e,
dcache,
parse,
produce,
// Wrap each reported result into a `PollEvent::Message` for the caller.
&mut |token, result, send_ts| {
handler(PollEvent::Message { token, payload: result, send_ts });
},
) == ConnState::Disconnected
{
// Notify the caller first, then let the manager handle the dead
// connection.
handler(PollEvent::Disconnect { token: event_token });
self.disconnect_at_index(stream_id);
}
return;
}
ConnectionVariant::Listener(tcp_listener) => {
if let Ok((mut stream, addr)) = tcp_listener.accept() {
tracing::info!(?addr, "client connected");
if let Some(size) = self.socket_buf_size {
set_socket_buf_size(&stream, size);
}
// `next_token` is only advanced once the connection is fully set up
// (see the end of this arm), so a failed attempt reuses the value.
let token = Token(self.next_token);
if let Err(e) =
self.poll.registry().register(&mut stream, token, Interest::READABLE)
{
error!("couldn't register client {e}");
let _ = stream.shutdown(std::net::Shutdown::Both);
// Give up on this client and try to accept the next one.
continue;
};
if let Err(e) = stream.set_nodelay(true) {
error!("couldn't set nodelay on stream to {addr}: {e}");
continue;
}
let mut conn = TcpStream::from_stream_with_telemetry(
stream,
token,
addr,
self.telemetry,
// NOTE(review): presumably tells the stream whether it receives into
// the dcache — confirm against `from_stream_with_telemetry`.
self.dcache.is_some(),
);
// If an on-connect message is configured and writing/enqueueing it
// reports a disconnect, drop this connection and accept the next one.
if let Some(msg) = &self.on_connect_msg &&
conn.write_or_enqueue_with(
self.poll.registry(),
|buf: &mut Vec<u8>| {
buf.extend_from_slice(msg);
},
) == ConnState::Disconnected
{
continue;
}
// Announce the new connection before it is stored.
handler(PollEvent::Accept {
listener: event_token,
stream: token,
peer_addr: addr,
});
self.conns.push((token, ConnectionVariant::Inbound(conn)));
self.next_token += 1;
} else {
// NOTE(review): every `accept()` error ends the loop silently, not only
// `WouldBlock` — confirm that other errors need no logging or retry.
return;
}
}
}
}
}
}

/// Non-blocking TCP connector/acceptor built on `mio`.
Expand Down Expand Up @@ -456,9 +552,9 @@ impl TcpConnector {
self
}

/// Attaches a dcache writer as the shared receive buffer for all streams.
pub fn with_dcache(mut self, writer: Arc<DCache>) -> Self {
self.conn_mgr.dcache = Some(writer);
/// Attaches a dcache as the shared receive buffer for all streams.
pub fn with_dcache(mut self, dcache: DCachePtr) -> Self {
self.conn_mgr.dcache = Some(dcache);
self
}

Expand Down Expand Up @@ -486,7 +582,7 @@ impl TcpConnector {
#[inline]
pub fn poll_with<F>(&mut self, mut handler: F) -> bool
where
F: for<'a> FnMut(PollEvent<'a>),
F: for<'a> FnMut(PollEvent<&'a [u8]>),
{
self.conn_mgr.maybe_reconnect();
for token in self.conn_mgr.reconnected_to.drain(..) {
Expand All @@ -506,6 +602,43 @@ impl TcpConnector {
o
}

/// Runs one non-blocking poll cycle on the dcache "produce" path.
///
/// Counterpart of [`poll_with`] for dcache-backed streams. Before polling,
/// pending reconnects are attempted and a `PollEvent::Reconnect` is emitted
/// for every token that was reconnected. For each received message `parse`
/// is applied to the payload bytes; when it returns `Ok(t)` the value is
/// published through `produce` (via `produce_with_dref`). The handler then
/// receives `PollEvent::Message { payload: Result<T, E>, .. }` carrying the
/// parse result, whether `Ok` or `Err`.
///
/// Returns `true` if the poll yielded at least one event, and `false` if
/// there were none or if polling itself failed.
///
/// # Panics
/// Panics if a stream event arrives and no dcache was configured via
/// [`with_dcache`].
#[inline]
pub fn poll_with_produce<T, E, G, P, F>(
    &mut self,
    parse: &mut G,
    produce: &mut P,
    mut handler: F,
) -> bool
where
    T: 'static + Copy,
    G: FnMut(Token, &[u8]) -> Result<T, E>,
    P: SpineProducers + AsRef<SpineProducerWithDCache<T>>,
    F: FnMut(PollEvent<Result<T, E>>),
{
    let mgr = &mut self.conn_mgr;

    // Re-establish dropped outbound connections and report each success.
    mgr.maybe_reconnect();
    mgr.reconnected_to.drain(..).for_each(|token| handler(PollEvent::Reconnect { token }));

    // Zero timeout: collect whatever is ready right now, never block.
    match mgr.poll.poll(&mut self.events, Some(std::time::Duration::ZERO)) {
        Ok(()) => {}
        Err(e) => {
            safe_panic!("got error polling {e}");
            return false;
        }
    }

    let saw_events = !self.events.is_empty();
    for event in self.events.iter() {
        mgr.handle_event_produce(event, parse, produce, &mut handler);
    }

    mgr.flush_backlogs();
    saw_events
}

/// Writes immediately or enqueues bytes for later sending.
///
/// `serialise` is called with a mutable send buffer and must return the
Expand Down
2 changes: 1 addition & 1 deletion crates/flux-network/src/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ mod connector;
mod stream;

pub use connector::{PollEvent, SendBehavior, TcpConnector};
pub use stream::{ConnState, MessagePayload, TcpStream, TcpTelemetry};
pub use stream::{ConnState, TcpStream, TcpTelemetry};
71 changes: 64 additions & 7 deletions crates/flux-network/src/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
net::SocketAddr,
};

use flux::spine::{SpineProducerWithDCache, SpineProducers};
use flux_communication::Timer;
use flux_timing::Nanos;
use flux_utils::{DCache, DCacheRef};
Expand All @@ -15,7 +16,7 @@ enum RxBuf {
use mio::{Interest, Registry, Token, event::Event};
use tracing::{debug, warn};

pub enum MessagePayload<'a> {
enum MessagePayload<'a> {
Raw(&'a [u8]),
Cached(DCacheRef),
}
Expand Down Expand Up @@ -198,9 +199,9 @@ impl TcpStream {

/// Poll socket and calls `on_msg` for every fully assembled frame.
///
/// When no DCache is set, `payload` is [`MessagePayload::Raw`] and the
/// slice is only valid for the duration of the callback. When DCache is
/// set, `payload` is [`MessagePayload::Cached`] and the ref may be kept.
/// The byte slice passed to `on_msg` is only valid for the duration of the
/// callback. Use with non-dcache connectors; for dcache use
/// [`poll_with_produce`].
#[inline]
pub fn poll_with<F>(
&mut self,
Expand All @@ -210,13 +211,69 @@ impl TcpStream {
on_msg: &mut F,
) -> ConnState
where
F: for<'a> FnMut(Token, MessagePayload<'a>, Nanos),
F: for<'a> FnMut(Token, &'a [u8], Nanos),
{
if ev.is_readable() {
loop {
match self.read_frame(dcache) {
ReadOutcome::PayloadDone { payload, send_ts } => {
on_msg(ev.token(), payload, send_ts);
ReadOutcome::PayloadDone { payload: MessagePayload::Raw(bytes), send_ts } => {
on_msg(ev.token(), bytes, send_ts);
}
ReadOutcome::PayloadDone { payload: MessagePayload::Cached(_), .. } => {
flux_utils::safe_panic!(
"poll_with called on dcache stream; use poll_with_produce"
);
}
ReadOutcome::WouldBlock => break,
ReadOutcome::Disconnected => return ConnState::Disconnected,
}
}
}

if ev.is_writable() && self.drain_backlog(registry) == ConnState::Disconnected {
return ConnState::Disconnected;
}

ConnState::Alive
}

/// Like [`poll_with`] but for dcache-backed streams. Applies `parse` to
/// each payload's bytes; on `Ok(t)` it calls
/// `produce.produce_with_dref(t, dref, send_ts)`, and in either case the
/// parse result is then passed to `on_msg`. Use with dcache connectors; for
/// raw use [`poll_with`].
#[inline]
pub fn poll_with_produce<T, E, G, P, F>(
&mut self,
registry: &Registry,
ev: &Event,
dcache: &DCache,
parse: &mut G,
produce: &mut P,
on_msg: &mut F,
) -> ConnState
where
T: 'static + Copy,
G: FnMut(Token, &[u8]) -> Result<T, E>,
P: SpineProducers + AsRef<SpineProducerWithDCache<T>>,
F: FnMut(Token, Result<T, E>, Nanos),
{
if ev.is_readable() {
loop {
match self.read_frame(Some(dcache)) {
ReadOutcome::PayloadDone { payload: MessagePayload::Raw(_), .. } => {
flux_utils::safe_panic!(
"poll_with_produce called on non-dcache stream; use poll_with"
);
}
ReadOutcome::PayloadDone { payload: MessagePayload::Cached(dref), send_ts } => {
match dcache.map(dref, |bytes| parse(self.token, bytes)) {
Ok(result) => {
if let Ok(t) = &result {
produce.produce_with_dref(*t, dref, send_ts);
}
on_msg(ev.token(), result, send_ts);
}
Err(e) => warn!("dcache map failed: {e}"),
}
}
ReadOutcome::WouldBlock => break,
ReadOutcome::Disconnected => return ConnState::Disconnected,
Expand Down
8 changes: 3 additions & 5 deletions crates/flux-network/tests/tcp_broadcast_burst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
time::Duration,
};

use flux_network::tcp::{MessagePayload, PollEvent, SendBehavior, TcpConnector};
use flux_network::tcp::{PollEvent, SendBehavior, TcpConnector};

const NUM_RECEIVERS: usize = 4;
const BURST_SIZE: usize = 20;
Expand All @@ -26,10 +26,8 @@ fn spawn_receiver(addr: SocketAddr) -> thread::JoinHandle<Vec<Vec<u8>>> {

while !disconnected && std::time::Instant::now() < deadline {
conn.poll_with(|event| match event {
PollEvent::Message { payload, .. } => {
if let MessagePayload::Raw(bytes) = payload {
frames.push(bytes.to_vec());
}
PollEvent::Message { payload: bytes, .. } => {
frames.push(bytes.to_vec());
}
PollEvent::Disconnect { .. } => {
disconnected = true;
Expand Down
Loading