diff --git a/tokio-stream/src/stream_ext/collect.rs b/tokio-stream/src/stream_ext/collect.rs index 8140f0b4e4b..f4ca2dfad9d 100644 --- a/tokio-stream/src/stream_ext/collect.rs +++ b/tokio-stream/src/stream_ext/collect.rs @@ -6,6 +6,7 @@ use core::mem; use core::pin::Pin; use core::task::{ready, Context, Poll}; use pin_project_lite::pin_project; +use std::collections::BTreeSet; // Do not export this struct until `FromStream` can be unsealed. pin_project! { @@ -135,6 +136,25 @@ impl sealed::FromStreamPriv for Vec { } } +impl FromStream for BTreeSet {} + +impl sealed::FromStreamPriv for BTreeSet { + type InternalCollection = BTreeSet; + + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option) -> BTreeSet { + BTreeSet::new() + } + + fn extend(_: sealed::Internal, collection: &mut BTreeSet, item: T) -> bool { + collection.insert(item); + true + } + + fn finalize(_: sealed::Internal, collection: &mut BTreeSet) -> BTreeSet { + mem::take(collection) + } +} + impl FromStream for Box<[T]> {} impl sealed::FromStreamPriv for Box<[T]> { diff --git a/tokio-stream/tests/stream_collect.rs b/tokio-stream/tests/stream_collect.rs index 07659a1fc3d..d7c6b423488 100644 --- a/tokio-stream/tests/stream_collect.rs +++ b/tokio-stream/tests/stream_collect.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeSet; + use tokio_stream::{self as stream, StreamExt}; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; @@ -61,6 +63,27 @@ async fn collect_vec_items() { assert_eq!(vec![1, 2], coll); } +#[tokio::test] +async fn collect_btreeset_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(BTreeSet::from([1, 2]), coll); +} + #[tokio::test] async fn collect_string_items() { let (tx, rx) = mpsc::unbounded_channel_stream();