diff --git a/rtc-datachannel/src/data_channel/mod.rs b/rtc-datachannel/src/data_channel/mod.rs index 1411abd1..cf9f2873 100644 --- a/rtc-datachannel/src/data_channel/mod.rs +++ b/rtc-datachannel/src/data_channel/mod.rs @@ -71,23 +71,21 @@ impl DataChannel { ) -> Result { let mut data_channel = DataChannel::new(config.clone(), association_handle, stream_id); - if !config.negotiated { - let msg = Message::DataChannelOpen(DataChannelOpen { - channel_type: config.channel_type, - priority: config.priority, - reliability_parameter: config.reliability_parameter, - label: config.label.bytes().collect(), - protocol: config.protocol.bytes().collect(), - }) - .marshal()?; - - data_channel.write_outs.push_back(DataChannelMessage { - association_handle, - stream_id, - ppi: PayloadProtocolIdentifier::Dcep, - payload: msg, - }); - } + let msg = Message::DataChannelOpen(DataChannelOpen { + channel_type: config.channel_type, + priority: config.priority, + reliability_parameter: config.reliability_parameter, + label: config.label.bytes().collect(), + protocol: config.protocol.bytes().collect(), + }) + .marshal()?; + + data_channel.write_outs.push_back(DataChannelMessage { + association_handle, + stream_id, + ppi: PayloadProtocolIdentifier::Dcep, + payload: msg, + }); Ok(data_channel) } diff --git a/rtc-sctp/src/association/mod.rs b/rtc-sctp/src/association/mod.rs index dc7e1bc8..2180a1bf 100644 --- a/rtc-sctp/src/association/mod.rs +++ b/rtc-sctp/src/association/mod.rs @@ -26,13 +26,14 @@ use crate::param::{ use crate::queue::{payload_queue::PayloadQueue, pending_queue::PendingQueue}; use crate::shared::{AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner}; use crate::util::{sna16lt, sna32gt, sna32gte, sna32lt, sna32lte}; -use crate::{AssociationEvent, Payload, Side}; +use crate::{AssociationEvent, FlushIds, Payload, Side}; use shared::error::{Error, Result}; use shared::{TransportContext, TransportMessage, TransportProtocol}; use stream::{ReliabilityType, Stream, StreamEvent, StreamId, StreamState}; use timer::{ACK_INTERVAL, RtoManager, Timer, TimerTable}; use crate::association::stream::RecvSendState; +use crate::queue::pending_queue::{FlushEntry, QueueEntry}; use bytes::Bytes; use log::{debug, error, trace, warn}; use rand::random; @@ -429,6 +430,22 @@ impl Association { /// - a call was made to `handle_timeout` #[must_use] pub fn poll_transmit(&mut self, now: Instant) -> Option> { + + // first, see if the next queue entry is a flush signal + if let Some(ids) = self.pop_pending_flush() { + trace!("polled flush({})", ids.flush_id); + return Some(TransportMessage { + now, + transport: TransportContext { + local_addr: self.local_addr, + peer_addr: self.remote_addr, + ecn: None, + transport_protocol: Default::default(), + }, + message: Payload::Flush(ids), + }); + } + let (contents, _) = self.gather_outbound(now); if contents.is_empty() { None @@ -2375,6 +2392,24 @@ impl Association { self.bundle_data_chunks_into_packets(chunks) } + fn pop_pending_flush(&mut self) -> Option { + + // if the first queue entry is a flush signal, pop it off + if let Some(QueueEntry::Flush(e)) = self.pending_queue.peek() { + let unordered = e.unordered; + match self.pending_queue.pop(true, unordered) { + Some(QueueEntry::Flush(e)) => Some(e.ids), + _ => None + } + } else { + None + } + + // TODO: is popping off the pending queue enough to guarantee all the previous messages + // have been written to the final output queue? + // TODO: pop multiple consecutive flush signals? + } + /// pop_pending_data_chunks_to_send pops chunks from the pending queues as many as /// the cwnd and rwnd allows to send. fn pop_pending_data_chunks_to_send( @@ -2392,7 +2427,7 @@ impl Association { // is 0), the data sender can always have one DATA chunk in flight to // the receiver if allowed by cwnd (see rule B, below). - while let Some(c) = self.pending_queue.peek() { + while let Some(QueueEntry::Payload(c)) = self.pending_queue.peek() { let (beginning_fragment, unordered, data_len, stream_identifier) = ( c.beginning_fragment, c.unordered, @@ -2434,7 +2469,7 @@ impl Association { // the data sender can always have one DATA chunk in flight to the receiver if chunks.is_empty() && self.inflight_queue.is_empty() { // Send zero window probe - if let Some(c) = self.pending_queue.peek() { + if let Some(QueueEntry::Payload(c)) = self.pending_queue.peek() { let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered); if let Some(chunk) = self.move_pending_data_chunk_to_inflight_queue( @@ -2608,7 +2643,7 @@ impl Association { unordered: bool, now: Instant, ) -> Option { - if let Some(mut c) = self.pending_queue.pop(beginning_fragment, unordered) { + if let Some(QueueEntry::Payload(mut c)) = self.pending_queue.pop(beginning_fragment, unordered) { // Mark all fragements are in-flight now if c.ending_fragment { c.set_all_inflight(); @@ -2665,7 +2700,7 @@ impl Association { ..Default::default() }; - self.pending_queue.push(c); + self.pending_queue.push(QueueEntry::Payload(c)); self.awake_write_loop(); Ok(()) @@ -2680,9 +2715,22 @@ impl Association { // Push the chunks into the pending queue first. for c in chunks { - self.pending_queue.push(c); + self.pending_queue.push(QueueEntry::Payload(c)); + } + + self.awake_write_loop(); + Ok(()) + } + + pub(crate) fn send_flush(&mut self, ids: FlushIds, unordered: bool) -> Result<()> { + + let state = self.state(); + if state != AssociationState::Established { + return Err(Error::ErrPayloadDataStateNotExist); } + self.pending_queue.push(QueueEntry::Flush(FlushEntry { ids, unordered })); + self.awake_write_loop(); Ok(()) } diff --git a/rtc-sctp/src/association/stream.rs b/rtc-sctp/src/association/stream.rs index 434910cc..c7410724 100644 --- a/rtc-sctp/src/association/stream.rs +++ b/rtc-sctp/src/association/stream.rs @@ -2,7 +2,7 @@ use crate::association::Association; use crate::association::state::AssociationState; use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; use crate::queue::reassembly_queue::{Chunks, ReassemblyQueue}; -use crate::{ErrorCauseCode, Event, Side}; +use crate::{ErrorCauseCode, Event, FlushIds, Side}; use shared::error::{Error, Result}; use crate::util::{ByteSlice, BytesArray, BytesSource}; @@ -204,6 +204,22 @@ impl Stream<'_> { } } + /// Pushes a flush signal into the stream, which can be collected later by polling + /// the connection after all previous messages are processed. + /// The flush signal is not sent to the remote peer. + pub fn flush(&mut self, ids: FlushIds) -> Result<()> { + + if !self.is_writable() { + return Err(Error::ErrStreamClosed); + } + + let Some(s) = self.association.streams.get_mut(&self.stream_identifier) + else { return Err(Error::ErrStreamClosed); }; + let unordered = s.unordered; + + self.association.send_flush(ids, unordered) + } + pub fn is_readable(&self) -> bool { if let Some(s) = self.association.streams.get(&self.stream_identifier) { s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable diff --git a/rtc-sctp/src/lib.rs b/rtc-sctp/src/lib.rs index 4b0a821f..b7c55fa7 100644 --- a/rtc-sctp/src/lib.rs +++ b/rtc-sctp/src/lib.rs @@ -105,4 +105,14 @@ use crate::packet::PartialDecode; pub enum Payload { PartialDecode(PartialDecode), RawEncode(Vec), + Flush(FlushIds) +} + + +#[derive(Debug, Clone)] +pub struct FlushIds { + pub flush_id: i64, + pub data_channel_id: u16, + pub association_handle: usize, + pub stream_id: u16 } diff --git a/rtc-sctp/src/queue/pending_queue.rs b/rtc-sctp/src/queue/pending_queue.rs index 39e96360..6bcd0d23 100644 --- a/rtc-sctp/src/queue/pending_queue.rs +++ b/rtc-sctp/src/queue/pending_queue.rs @@ -1,9 +1,10 @@ use crate::chunk::chunk_payload_data::ChunkPayloadData; +use crate::FlushIds; use std::collections::VecDeque; /// pendingBaseQueue -pub(crate) type PendingBaseQueue = VecDeque; +pub(crate) type PendingBaseQueue = VecDeque; /// pendingQueue #[derive(Debug, Default)] @@ -21,17 +22,17 @@ impl PendingQueue { PendingQueue::default() } - pub(crate) fn push(&mut self, c: ChunkPayloadData) { - self.n_bytes += c.user_data.len(); - if c.unordered { - self.unordered_queue.push_back(c); + pub(crate) fn push(&mut self, e: QueueEntry) { + self.n_bytes += e.len(); + if e.unordered() { + self.unordered_queue.push_back(e); } else { - self.ordered_queue.push_back(c); + self.ordered_queue.push_back(e); } self.queue_len += 1; } - pub(crate) fn peek(&self) -> Option<&ChunkPayloadData> { + pub(crate) fn peek(&self) -> Option<&QueueEntry> { if self.selected { if self.unordered_is_selected { return self.unordered_queue.front(); @@ -40,10 +41,10 @@ impl PendingQueue { } } - let c = self.unordered_queue.front(); + let e = self.unordered_queue.front(); - if c.is_some() { - return c; + if e.is_some() { + return e; } self.ordered_queue.front() @@ -53,15 +54,15 @@ impl PendingQueue { &mut self, beginning_fragment: bool, unordered: bool, - ) -> Option { + ) -> Option { let popped = if self.selected { let popped = if self.unordered_is_selected { self.unordered_queue.pop_front() } else { self.ordered_queue.pop_front() }; - if let Some(p) = &popped - && p.ending_fragment + if let Some(e) = &popped + && e.ending_fragment() == Some(true) { self.selected = false; } @@ -72,8 +73,8 @@ impl PendingQueue { } if unordered { let popped = { self.unordered_queue.pop_front() }; - if let Some(p) = &popped - && !p.ending_fragment + if let Some(e) = &popped + && e.ending_fragment() == Some(false) { self.selected = true; self.unordered_is_selected = true; @@ -81,8 +82,8 @@ impl PendingQueue { popped } else { let popped = { self.ordered_queue.pop_front() }; - if let Some(p) = &popped - && !p.ending_fragment + if let Some(e) = &popped + && e.ending_fragment() == Some(false) { self.selected = true; self.unordered_is_selected = false; @@ -91,8 +92,8 @@ impl PendingQueue { } }; - if let Some(p) = &popped { - self.n_bytes -= p.user_data.len(); + if let Some(e) = &popped { + self.n_bytes -= e.len(); self.queue_len -= 1; } @@ -111,3 +112,55 @@ impl PendingQueue { self.len() == 0 } } + + +#[derive(Debug)] +pub(crate) struct FlushEntry { + pub(crate) ids: FlushIds, + pub(crate) unordered: bool +} + +/// A queue entry can either be a chunk payload, or a flush signal +#[derive(Debug)] +pub(crate) enum QueueEntry { + Payload(ChunkPayloadData), + Flush(FlushEntry) +} + +impl QueueEntry { + + fn len(&self) -> usize { + match self { + Self::Payload(data) => data.user_data.len(), + Self::Flush(_) => 0 + } + } + + fn unordered(&self) -> bool { + match self { + Self::Payload(data) => data.unordered, + Self::Flush(flush) => flush.unordered + } + } + + fn ending_fragment(&self) -> Option { + match self { + Self::Payload(data) => Some(data.ending_fragment), + Self::Flush(_) => None + } + } + + pub fn as_payload(&self) -> &ChunkPayloadData { + match self { + Self::Payload(data) => data, + Self::Flush(_) => panic!("Expected QueueEntry::Payload, but was QueueEntry::Flush instead") + } + } + + pub fn into_payload(self) -> ChunkPayloadData { + match self { + Self::Payload(data) => data, + Self::Flush(_) => panic!("Expected QueueEntry::Payload, but was QueueEntry::Flush instead") + } + } +} diff --git a/rtc-sctp/src/queue/queue_test.rs b/rtc-sctp/src/queue/queue_test.rs index 05cd53a7..19ff875f 100644 --- a/rtc-sctp/src/queue/queue_test.rs +++ b/rtc-sctp/src/queue/queue_test.rs @@ -198,7 +198,7 @@ const FRAG_BEGIN: usize = 1; const FRAG_MIDDLE: usize = 2; const FRAG_END: usize = 3; -fn make_data_chunk(tsn: u32, unordered: bool, frag: usize) -> ChunkPayloadData { +fn make_data_chunk(tsn: u32, unordered: bool, frag: usize) -> QueueEntry { let mut b = false; let mut e = false; @@ -214,7 +214,7 @@ fn make_data_chunk(tsn: u32, unordered: bool, frag: usize) -> ChunkPayloadData { _ => {} }; - ChunkPayloadData { + QueueEntry::Payload(ChunkPayloadData { tsn, unordered, beginning_fragment: b, @@ -225,7 +225,7 @@ fn make_data_chunk(tsn: u32, unordered: bool, frag: usize) -> ChunkPayloadData { b.freeze() }, ..Default::default() - } + }) } #[test] @@ -236,13 +236,13 @@ fn test_pending_base_queue_push_and_pop() -> Result<()> { pq.push_back(make_data_chunk(2, false, NO_FRAGMENT)); for i in 0..3 { - let c = pq.get(i); + let c = pq.get(i).map(QueueEntry::as_payload); assert!(c.is_some(), "should not be none"); assert_eq!(i as u32, c.unwrap().tsn, "TSN should match"); } for i in 0..3 { - let c = pq.pop_front(); + let c = pq.pop_front().map(QueueEntry::into_payload); assert!(c.is_some(), "should not be none"); assert_eq!(i, c.unwrap().tsn, "TSN should match"); } @@ -251,7 +251,7 @@ fn test_pending_base_queue_push_and_pop() -> Result<()> { pq.push_back(make_data_chunk(4, false, NO_FRAGMENT)); for i in 3..5 { - let c = pq.pop_front(); + let c = pq.pop_front().map(QueueEntry::into_payload); assert!(c.is_some(), "should not be none"); assert_eq!(i, c.unwrap().tsn, "TSN should match"); } @@ -283,7 +283,7 @@ fn test_pending_queue_push_and_pop() -> Result<()> { assert_eq!(30, pq.get_num_bytes(), "total bytes mismatch"); for i in 0..3 { - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(i, c.tsn, "TSN should match"); @@ -301,7 +301,7 @@ fn test_pending_queue_push_and_pop() -> Result<()> { assert_eq!(20, pq.get_num_bytes(), "total bytes mismatch"); for i in 3..5 { - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(i, c.tsn, "TSN should match"); @@ -329,7 +329,7 @@ fn test_pending_queue_unordered_wins() -> Result<()> { pq.push(make_data_chunk(3, true, NO_FRAGMENT)); assert_eq!(40, pq.get_num_bytes(), "total bytes mismatch"); - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(1, c.tsn, "TSN should match"); @@ -337,7 +337,7 @@ fn test_pending_queue_unordered_wins() -> Result<()> { let result = pq.pop(beginning_fragment, unordered); assert!(result.is_some(), "should not error"); - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(3, c.tsn, "TSN should match"); @@ -345,7 +345,7 @@ fn test_pending_queue_unordered_wins() -> Result<()> { let result = pq.pop(beginning_fragment, unordered); assert!(result.is_some(), "should not error"); - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(0, c.tsn, "TSN should match"); @@ -353,7 +353,7 @@ fn test_pending_queue_unordered_wins() -> Result<()> { let result = pq.pop(beginning_fragment, unordered); assert!(result.is_some(), "should not error"); - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(2, c.tsn, "TSN should match"); @@ -379,7 +379,7 @@ fn test_pending_queue_fragments() -> Result<()> { let expects = vec![3, 4, 5, 0, 1, 2]; for exp in expects { - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(exp, c.tsn, "TSN should match"); @@ -398,7 +398,7 @@ fn test_pending_queue_selection_persistence() -> Result<()> { let mut pq = PendingQueue::new(); pq.push(make_data_chunk(0, false, FRAG_BEGIN)); - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(0, c.tsn, "TSN should match"); @@ -413,7 +413,7 @@ fn test_pending_queue_selection_persistence() -> Result<()> { let expects = vec![2, 3, 1]; for exp in expects { - let c = pq.peek(); + let c = pq.peek().map(QueueEntry::as_payload); assert!(c.is_some(), "peek error"); let c = c.unwrap(); assert_eq!(exp, c.tsn, "TSN should match"); diff --git a/rtc/src/data_channel/mod.rs b/rtc/src/data_channel/mod.rs index 5b362a07..c8e8a2d4 100644 --- a/rtc/src/data_channel/mod.rs +++ b/rtc/src/data_channel/mod.rs @@ -34,7 +34,7 @@ //! * [RFC 8831 - WebRTC Data Channels](https://www.rfc-editor.org/rfc/rfc8831.html) //! * [RFC 8832 - WebRTC Data Channel Establishment Protocol](https://www.rfc-editor.org/rfc/rfc8832.html) -use crate::peer_connection::RTCPeerConnection; +use crate::peer_connection::{RTCPeerConnection, FlushId}; use crate::peer_connection::message::RTCMessage; use bytes::BytesMut; use interceptor::{Interceptor, NoopInterceptor}; @@ -273,6 +273,22 @@ where } } + /// `flush` sends a signal that indicates when all previous data channel messages are + /// finished sending. + /// After calling `flush`, a future call to `RTCPeerConnection::poll_flush` will emit a signal + /// with the same `id` indicating that all socket messages corresponding to the previous + /// data channel messages have been delivered via `RTCPeerConnection::poll_write`. + pub fn flush(&mut self, id: i64) -> Result<()> { + if self.peer_connection.data_channels.contains_key(&self.id) { + self.peer_connection.flush(FlushId { + flush_id: id, + data_channel_id: self.id + }) + } else { + Err(Error::ErrDataChannelClosed) + } + } + pub fn close(&mut self) -> Result<()> { if let Some(dc) = self.peer_connection.data_channels.get_mut(&self.id) { if dc.ready_state == RTCDataChannelState::Closed { diff --git a/rtc/src/peer_connection/handler/datachannel.rs b/rtc/src/peer_connection/handler/datachannel.rs index 260bea94..a8e0979f 100644 --- a/rtc/src/peer_connection/handler/datachannel.rs +++ b/rtc/src/peer_connection/handler/datachannel.rs @@ -220,6 +220,28 @@ impl<'a> sansio::Protocol sansio::Protocol { + // pass along the flush message + self.ctx.write_outs.push_back(TaggedRTCMessageInternal { + now: msg.now, + transport: msg.transport, + message: RTCMessageInternal::Flush(message), + }); + } _ => { debug!("drop non-RAW packet {:?}", msg.message); } diff --git a/rtc/src/peer_connection/handler/mod.rs b/rtc/src/peer_connection/handler/mod.rs index f212fca0..1a0e1c31 100644 --- a/rtc/src/peer_connection/handler/mod.rs +++ b/rtc/src/peer_connection/handler/mod.rs @@ -7,7 +7,7 @@ pub(crate) mod interceptor; pub(crate) mod sctp; pub(crate) mod srtp; -use crate::peer_connection::RTCPeerConnection; +use crate::peer_connection::{FlushId, RTCPeerConnection}; use crate::peer_connection::event::RTCPeerConnectionEvent; use crate::peer_connection::event::{RTCEvent, RTCEventInternal}; use crate::peer_connection::handler::datachannel::{DataChannelHandler, DataChannelHandlerContext}; @@ -101,6 +101,7 @@ macro_rules! for_each_handler { }; } + #[derive(Default)] pub(crate) struct PipelineContext { // Handler contexts @@ -117,6 +118,7 @@ pub(crate) struct PipelineContext { pub(crate) read_outs: VecDeque, pub(crate) write_outs: VecDeque, pub(crate) event_outs: VecDeque, + pub(crate) flush_outs: Option>, // Statistics accumulator pub(crate) stats: RTCStatsAccumulator, @@ -306,6 +308,15 @@ where transport: msg.transport, message, }); + } else if let RTCMessageInternal::Flush(message) = msg.message { + // create the flush queue, if needed, on first use + if self.pipeline_context.flush_outs.is_none() { + self.pipeline_context.flush_outs = Some(VecDeque::new()); + } + // push the flush message to the queue + if let Some(flush_outs) = self.pipeline_context.flush_outs.as_mut() { + flush_outs.push_back(message.id); + } } } diff --git a/rtc/src/peer_connection/handler/sctp.rs b/rtc/src/peer_connection/handler/sctp.rs index 8f1d4f41..e83622a2 100644 --- a/rtc/src/peer_connection/handler/sctp.rs +++ b/rtc/src/peer_connection/handler/sctp.rs @@ -1,9 +1,9 @@ use crate::peer_connection::event::RTCEventInternal; use crate::peer_connection::event::RTCPeerConnectionEvent; use crate::peer_connection::event::data_channel_event::RTCDataChannelEvent; -use crate::peer_connection::handler::DEFAULT_TIMEOUT_DURATION; +use crate::peer_connection::handler::{FlushId, DEFAULT_TIMEOUT_DURATION}; use crate::peer_connection::message::internal::{ - DTLSMessage, RTCMessageInternal, TaggedRTCMessageInternal, + DTLSMessage, FlushMessage, RTCMessageInternal, TaggedRTCMessageInternal }; use crate::peer_connection::transport::sctp::RTCSctpTransport; use bytes::BytesMut; @@ -13,7 +13,7 @@ use datachannel::message::message_channel_threshold::DataChannelThreshold; use log::{debug, warn}; use sctp::{ AssociationEvent, AssociationHandle, ClientConfig, DatagramEvent, EndpointEvent, Event, - Payload, PayloadProtocolIdentifier, StreamEvent, + FlushIds, Payload, PayloadProtocolIdentifier, StreamEvent }; use shared::error::{Error, Result}; use shared::marshal::Unmarshal; @@ -223,17 +223,7 @@ impl<'a> sansio::Protocol { - if let Payload::RawEncode(raw_data) = transmit.message { - for raw in raw_data { - self.ctx.write_outs.push_back(TaggedRTCMessageInternal { - now: transmit.now, - transport: transmit.transport, - message: RTCMessageInternal::Dtls(DTLSMessage::Raw( - BytesMut::from(&raw[..]), - )), - }); - } - } + write_transmit(transmit, &mut self.ctx.write_outs); } } } @@ -338,18 +328,39 @@ impl<'a> sansio::Protocol sansio::Protocol sansio::Protocol) -> Vec> { + let mut transmits: Vec> = Vec::new(); - if let Payload::RawEncode(contents) = transmit.message { - for content in contents { + + match transmit.message { + + Payload::RawEncode(contents) => { + for content in contents { + transmits.push(TransportMessage { + now: transmit.now, + transport: transmit.transport, + message: Payload::RawEncode(vec![content]), + }) + } + } + + // pass through flush messages intact + Payload::Flush(ids) => { transmits.push(TransportMessage { now: transmit.now, transport: transmit.transport, - message: Payload::RawEncode(vec![content]), + message: Payload::Flush(ids), }) } + + _ => () } transmits } + +fn write_transmit(transmit: TransportMessage, write_outs: &mut VecDeque) { + match transmit.message { + + Payload::RawEncode(raw_data) => { + for raw in raw_data { + write_outs.push_back(TaggedRTCMessageInternal { + now: transmit.now, + transport: transmit.transport, + message: RTCMessageInternal::Dtls(DTLSMessage::Raw(BytesMut::from( + &raw[..], + ))), + }); + } + } + + // pass the flush message along to the write out queue + Payload::Flush(ids) => { + debug!("flush({}) completed for channel {}", ids.flush_id, ids.data_channel_id); + write_outs.push_back(TaggedRTCMessageInternal { + now: transmit.now, + transport: transmit.transport, + message: RTCMessageInternal::Flush(FlushMessage { + id: FlushId { + flush_id: ids.flush_id, + data_channel_id: ids.data_channel_id + }, + association_handle: ids.association_handle, + stream_id: ids.stream_id + }) + }); + } + + _ => {} // drop all other messages + } +} diff --git a/rtc/src/peer_connection/message/internal.rs b/rtc/src/peer_connection/message/internal.rs index d8eee476..97c133aa 100644 --- a/rtc/src/peer_connection/message/internal.rs +++ b/rtc/src/peer_connection/message/internal.rs @@ -1,6 +1,7 @@ use crate::data_channel::RTCDataChannelId; use crate::data_channel::message::RTCDataChannelMessage; use crate::media_stream::track::MediaStreamTrackId; +use crate::peer_connection::FlushId; use bytes::BytesMut; use datachannel::data_channel::DataChannelMessage; use interceptor::Packet; @@ -46,12 +47,31 @@ pub(crate) enum RTPMessage { TrackPacket(TrackPacket), } +#[derive(Debug, Clone)] +pub(crate) struct FlushMessage { + pub(crate) id: FlushId, + pub(crate) association_handle: usize, + pub(crate) stream_id: u16 +} + +impl FlushMessage { + pub(crate) fn new(id: FlushId) -> Self { + Self { + id, + // we will gather the values below later in the processing chain + association_handle: 0, + stream_id: 0 + } + } +} + #[derive(Debug, Clone)] pub(crate) enum RTCMessageInternal { Raw(BytesMut), Stun(STUNMessage), Dtls(DTLSMessage), Rtp(RTPMessage), + Flush(FlushMessage) } impl RTCMessageInternal { @@ -93,6 +113,7 @@ impl RTCMessageInternal { } }, }, + RTCMessageInternal::Flush(_) => 0 } } } diff --git a/rtc/src/peer_connection/mod.rs b/rtc/src/peer_connection/mod.rs index dca5c468..98dac2d9 100644 --- a/rtc/src/peer_connection/mod.rs +++ b/rtc/src/peer_connection/mod.rs @@ -310,6 +310,7 @@ use shared::error::{Error, Result}; use shared::util::math_rand_alpha; use std::collections::HashMap; use std::time::Instant; +use crate::peer_connection::message::internal::{FlushMessage, RTCMessageInternal, TaggedRTCMessageInternal}; /// Builder for creating RTCPeerConnection instances. /// @@ -2234,4 +2235,35 @@ where .stats .snapshot_with_selector(now, selector) } + + /// `flush` sends a signal that indicates when all previous messages on the given data channel + /// are finished sending. + /// After calling `flush`, a future call to `poll_flush` will emit a signal + /// with the same `id` indicating that all socket messages corresponding to the previous + /// data channel messages have been delivered via `poll_write`. + pub fn flush(&mut self, id: FlushId) -> std::result::Result<(), Error> { + let mut endpoint_handler = self.get_endpoint_handler(); + use sansio::Protocol; + endpoint_handler.handle_write(TaggedRTCMessageInternal { + now: Instant::now(), + transport: Default::default(), + message: RTCMessageInternal::Flush(FlushMessage::new(id)) + }) + } + + /// `poll_flush` will return the signal given to `flush` after all previous + /// data channel messages have been delivered via `poll_write`. + pub fn poll_flush(&mut self) -> Option { + self.pipeline_context.flush_outs.as_mut() + .and_then(|flush_outs| flush_outs.pop_front()) + } +} + + +/// A unique identifier for a flush signal on a data channel +#[derive(Debug, Clone)] +pub struct FlushId { + /// a caller-chosen value that will be presented again when the flush signal is eventually polled + pub flush_id: i64, + pub data_channel_id: RTCDataChannelId }