diff --git a/tokio-stream/src/stream_ext/collect.rs b/tokio-stream/src/stream_ext/collect.rs index f4ca2dfad9d..03815a1c830 100644 --- a/tokio-stream/src/stream_ext/collect.rs +++ b/tokio-stream/src/stream_ext/collect.rs @@ -6,7 +6,8 @@ use core::mem; use core::pin::Pin; use core::task::{ready, Context, Poll}; use pin_project_lite::pin_project; -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}; +use std::hash::Hash; // Do not export this struct until `FromStream` can be unsealed. pin_project! { @@ -136,6 +137,44 @@ impl sealed::FromStreamPriv for Vec { } } +impl FromStream for VecDeque {} + +impl sealed::FromStreamPriv for VecDeque { + type InternalCollection = VecDeque; + + fn initialize(_: sealed::Internal, lower: usize, _upper: Option) -> VecDeque { + VecDeque::with_capacity(lower) + } + + fn extend(_: sealed::Internal, collection: &mut VecDeque, item: T) -> bool { + collection.push_back(item); + true + } + + fn finalize(_: sealed::Internal, collection: &mut VecDeque) -> VecDeque { + mem::take(collection) + } +} + +impl FromStream for LinkedList {} + +impl sealed::FromStreamPriv for LinkedList { + type InternalCollection = LinkedList; + + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option) -> LinkedList { + LinkedList::new() + } + + fn extend(_: sealed::Internal, collection: &mut LinkedList, item: T) -> bool { + collection.push_back(item); + true + } + + fn finalize(_: sealed::Internal, collection: &mut LinkedList) -> LinkedList { + mem::take(collection) + } +} + impl FromStream for BTreeSet {} impl sealed::FromStreamPriv for BTreeSet { @@ -155,6 +194,82 @@ impl sealed::FromStreamPriv for BTreeSet { } } +impl FromStream<(K, V)> for BTreeMap {} + +impl sealed::FromStreamPriv<(K, V)> for BTreeMap { + type InternalCollection = BTreeMap; + + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option) -> BTreeMap { + BTreeMap::new() + } + + fn extend(_: sealed::Internal, collection: &mut BTreeMap, (key, value): (K, V)) -> bool { + collection.insert(key, value); + true + } + + fn finalize(_: sealed::Internal, collection: &mut BTreeMap) -> BTreeMap { + mem::take(collection) + } +} + +impl FromStream for HashSet {} + +impl sealed::FromStreamPriv for HashSet { + type InternalCollection = HashSet; + + fn initialize(_: sealed::Internal, lower: usize, _upper: Option) -> HashSet { + HashSet::with_capacity(lower) + } + + fn extend(_: sealed::Internal, collection: &mut HashSet, item: T) -> bool { + collection.insert(item); + true + } + + fn finalize(_: sealed::Internal, collection: &mut HashSet) -> HashSet { + mem::take(collection) + } +} + +impl FromStream<(K, V)> for HashMap {} + +impl sealed::FromStreamPriv<(K, V)> for HashMap { + type InternalCollection = HashMap; + + fn initialize(_: sealed::Internal, lower: usize, _upper: Option) -> HashMap { + HashMap::with_capacity(lower) + } + + fn extend(_: sealed::Internal, collection: &mut HashMap, (key, value): (K, V)) -> bool { + collection.insert(key, value); + true + } + + fn finalize(_: sealed::Internal, collection: &mut HashMap) -> HashMap { + mem::take(collection) + } +} + +impl FromStream for BinaryHeap {} + +impl sealed::FromStreamPriv for BinaryHeap { + type InternalCollection = BinaryHeap; + + fn initialize(_: sealed::Internal, lower: usize, _upper: Option) -> BinaryHeap { + BinaryHeap::with_capacity(lower) + } + + fn extend(_: sealed::Internal, collection: &mut BinaryHeap, item: T) -> bool { + collection.push(item); + true + } + + fn finalize(_: sealed::Internal, collection: &mut BinaryHeap) -> BinaryHeap { + mem::take(collection) + } +} + impl FromStream for Box<[T]> {} impl sealed::FromStreamPriv for Box<[T]> { diff --git a/tokio-stream/tests/stream_collect.rs b/tokio-stream/tests/stream_collect.rs index d7c6b423488..0cd05e339bb 100644 --- a/tokio-stream/tests/stream_collect.rs +++ b/tokio-stream/tests/stream_collect.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}; use tokio_stream::{self as stream, StreamExt}; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; @@ -63,6 +63,48 @@ async fn collect_vec_items() { assert_eq!(vec![1, 2], coll); } +#[tokio::test] +async fn collect_vecdeque_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(VecDeque::from([1, 2]), coll); +} + +#[tokio::test] +async fn collect_linkedlist_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(LinkedList::from([1, 2]), coll); +} + #[tokio::test] async fn collect_btreeset_items() { let (tx, rx) = mpsc::unbounded_channel_stream(); @@ -84,6 +126,90 @@ async fn collect_btreeset_items() { assert_eq!(BTreeSet::from([1, 2]), coll); } +#[tokio::test] +async fn collect_btreemap_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send((3, 4)).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send((1, 2)).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(BTreeMap::from([(1, 2), (3, 4)]), coll); +} + +#[tokio::test] +async fn collect_hashset_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(HashSet::from([1, 2]), coll); +} + +#[tokio::test] +async fn collect_hashmap_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send((1, 2)).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send((3, 4)).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(HashMap::from([(1, 2), (3, 4)]), coll); +} + +#[tokio::test] +async fn collect_binaryheap_items() { + let (tx, rx) = mpsc::unbounded_channel_stream(); + let mut fut = task::spawn(rx.collect::>()); + + assert_pending!(fut.poll()); + + tx.send(2).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + tx.send(1).unwrap(); + assert!(fut.is_woken()); + assert_pending!(fut.poll()); + + drop(tx); + assert!(fut.is_woken()); + let coll = assert_ready!(fut.poll()); + assert_eq!(vec![1, 2], coll.into_sorted_vec()); +} + #[tokio::test] async fn collect_string_items() { let (tx, rx) = mpsc::unbounded_channel_stream();