diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_bind.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_bind.rs index 7849494205..a725a48b56 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_bind.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_bind.rs @@ -61,20 +61,25 @@ async fn test_tcp_bind_reuseaddr(ip: IpAddress) { let connect_addr = IpSocketAddress::new(IpAddress::new_loopback(ip.family()), bind_addr.port()); - client.connect(connect_addr).await.unwrap(); - - let mut sock = accept.next().await.unwrap().unwrap(); - assert_eq!(sock.len(), 1); - let sock = sock.pop().unwrap(); - let (mut data_tx, data_rx) = wit_stream::new(); join!( async { - sock.send(data_rx).await.unwrap(); + client.connect(connect_addr).await.unwrap(); }, async { - data_tx.send(vec![0; 10]).await.unwrap(); - drop(data_tx); - } + let mut sock = accept.next().await.unwrap().unwrap(); + assert_eq!(sock.len(), 1); + let sock = sock.pop().unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + sock.send(data_rx).await.unwrap(); + }, + async { + data_tx.send(vec![0; 10]).await.unwrap(); + drop(data_tx); + } + ); + }, ); bind_addr diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_connect.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_connect.rs index 57eb5529ee..cc394724dd 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_connect.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_connect.rs @@ -1,3 +1,4 @@ +use futures::{join, StreamExt as _}; use test_programs::p3::wasi::sockets::types::{ ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, }; @@ -96,12 +97,12 @@ async fn test_tcp_connect_dual_stack() { async fn test_tcp_connect_explicit_bind(family: IpAddressFamily) { let ip = IpAddress::new_loopback(family); - let listener = { + let (listener, mut accept) = { let bind_address = IpSocketAddress::new(ip, 0); let listener = TcpSocket::new(family); listener.bind(bind_address).unwrap(); - listener.listen().unwrap(); - listener + let accept = listener.listen().unwrap(); + (listener, accept) }; let listener_address = listener.local_address().unwrap(); @@ -111,7 +112,14 @@ async fn test_tcp_connect_explicit_bind(family: IpAddressFamily) { client.bind(IpSocketAddress::new(ip, 0)).unwrap(); // Connect should work: - client.connect(listener_address).await.unwrap(); + join!( + async { + client.connect(listener_address).await.unwrap(); + }, + async { + accept.next().await.unwrap().unwrap(); + } + ); } impl test_programs::p3::exports::wasi::cli::run::Guest for Component { diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_sample_application.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_sample_application.rs index 871fc57cb7..9a093765e3 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_sample_application.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_sample_application.rs @@ -20,63 +20,64 @@ async fn test_tcp_sample_application(family: IpAddressFamily, bind_address: IpSo let addr = listener.local_address().unwrap(); - { - let client = TcpSocket::new(family); - client.connect(addr).await.unwrap(); - let (mut data_tx, data_rx) = wit_stream::new(); - - join!( - async { - client.send(data_rx).await.unwrap(); - }, - async { - data_tx.send(vec![]).await.unwrap(); - data_tx.send(first_message.into()).await.unwrap(); - drop(data_tx); - } - ); - } - - { - let mut sock = accept.next().await.unwrap().unwrap(); - assert_eq!(sock.len(), 1); - let sock = sock.pop().unwrap(); - - let (mut data_rx, fut) = sock.receive(); - let data = data_rx.next().await.unwrap().unwrap(); - - // Check that we sent and received our message! - assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. - fut.await.unwrap().unwrap().unwrap() - } + join!( + async { + let client = TcpSocket::new(family); + client.connect(addr).await.unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + client.send(data_rx).await.unwrap(); + }, + async { + data_tx.send(vec![]).await.unwrap(); + data_tx.send(first_message.into()).await.unwrap(); + drop(data_tx); + } + ); + }, + async { + let mut sock = accept.next().await.unwrap().unwrap(); + assert_eq!(sock.len(), 1); + let sock = sock.pop().unwrap(); + + let (mut data_rx, fut) = sock.receive(); + let data = data_rx.next().await.unwrap().unwrap(); + + // Check that we sent and received our message! + assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. + fut.await.unwrap().unwrap().unwrap() + }, + ); // Another client - { - let client = TcpSocket::new(family); - client.connect(addr).await.unwrap(); - let (mut data_tx, data_rx) = wit_stream::new(); - join!( - async { - client.send(data_rx).await.unwrap(); - }, - async { - data_tx.send(second_message.into()).await.unwrap(); - drop(data_tx); - } - ); - } - - { - let mut sock = accept.next().await.unwrap().unwrap(); - assert_eq!(sock.len(), 1); - let sock = sock.pop().unwrap(); - let (mut data_rx, fut) = sock.receive(); - let data = data_rx.next().await.unwrap().unwrap(); - - // Check that we sent and received our message! - assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. - fut.await.unwrap().unwrap().unwrap() - } + join!( + async { + let client = TcpSocket::new(family); + client.connect(addr).await.unwrap(); + let (mut data_tx, data_rx) = wit_stream::new(); + join!( + async { + client.send(data_rx).await.unwrap(); + }, + async { + data_tx.send(second_message.into()).await.unwrap(); + drop(data_tx); + } + ); + }, + async { + let mut sock = accept.next().await.unwrap().unwrap(); + assert_eq!(sock.len(), 1); + let sock = sock.pop().unwrap(); + let (mut data_rx, fut) = sock.receive(); + let data = data_rx.next().await.unwrap().unwrap(); + + // Check that we sent and received our message! + assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. + fut.await.unwrap().unwrap().unwrap() + } + ); } impl test_programs::p3::exports::wasi::cli::run::Guest for Component { diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_sockopts.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_sockopts.rs index 456c9c1bc3..f52fd6e866 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_sockopts.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_sockopts.rs @@ -1,4 +1,4 @@ -use futures::StreamExt as _; +use futures::{join, StreamExt as _}; use test_programs::p3::wasi::sockets::types::{ ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, }; @@ -132,10 +132,16 @@ async fn test_tcp_sockopt_inheritance(family: IpAddressFamily) { let mut accept = listener.listen().unwrap(); let bound_addr = listener.local_address().unwrap(); let client = TcpSocket::new(family); - client.connect(bound_addr).await.unwrap(); - let mut sock = accept.next().await.unwrap().unwrap(); - assert_eq!(sock.len(), 1); - let sock = sock.pop().unwrap(); + let ((), sock) = join!( + async { + client.connect(bound_addr).await.unwrap(); + }, + async { + let mut sock = accept.next().await.unwrap().unwrap(); + assert_eq!(sock.len(), 1); + sock.pop().unwrap() + } + ); // Verify options on accepted socket: { @@ -194,10 +200,16 @@ async fn test_tcp_sockopt_after_listen(family: IpAddressFamily) { } let client = TcpSocket::new(family); - client.connect(bound_addr).await.unwrap(); - let mut sock = accept.next().await.unwrap().unwrap(); - assert_eq!(sock.len(), 1); - let sock = sock.pop().unwrap(); + let ((), sock) = join!( + async { + client.connect(bound_addr).await.unwrap(); + }, + async { + let mut sock = accept.next().await.unwrap().unwrap(); + assert_eq!(sock.len(), 1); + sock.pop().unwrap() + } + ); // Verify options on accepted socket: { diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_states.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_states.rs index f2b3828ade..b754b9652f 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_states.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_states.rs @@ -1,3 +1,4 @@ +use futures::{join, StreamExt as _}; use test_programs::p3::wasi::sockets::types::{ ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, }; @@ -146,10 +147,17 @@ async fn test_tcp_connected_state_invariants(family: IpAddressFamily) { let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); let sock_listener = TcpSocket::new(family); sock_listener.bind(bind_address).unwrap(); - sock_listener.listen().unwrap(); + let mut accept = sock_listener.listen().unwrap(); let addr_listener = sock_listener.local_address().unwrap(); let sock = TcpSocket::new(family); - sock.connect(addr_listener).await.unwrap(); + join!( + async { + sock.connect(addr_listener).await.unwrap(); + }, + async { + accept.next().await.unwrap().unwrap(); + } + ); assert_eq!(sock.bind(bind_address), Err(ErrorCode::InvalidState)); assert_eq!( diff --git a/crates/test-programs/src/bin/sockets_0_3_tcp_streams.rs b/crates/test-programs/src/bin/sockets_0_3_tcp_streams.rs index 328975fe11..ce6669b4a2 100644 --- a/crates/test-programs/src/bin/sockets_0_3_tcp_streams.rs +++ b/crates/test-programs/src/bin/sockets_0_3_tcp_streams.rs @@ -2,7 +2,7 @@ use core::future::Future; use futures::{join, SinkExt as _, StreamExt as _, TryStreamExt as _}; use test_programs::p3::wasi::sockets::types::{ - ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, + IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket, }; use test_programs::p3::wit_stream; @@ -13,17 +13,6 @@ test_programs::p3::export!(Component); /// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server. async fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(family: IpAddressFamily) { setup(family, |server, client| async move { - let (mut server_tx, server_rx) = wit_stream::new(); - join!( - async { - server.send(server_rx).await.unwrap(); - }, - async { - // Shut down the connection from the server side: - server_tx.close().await.unwrap(); - drop(server_tx); - }, - ); drop(server); let (mut client_rx, client_fut) = client.receive(); @@ -32,11 +21,9 @@ async fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(family: IpAdd // Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK) // See: https://github.com/bytecodealliance/wasmtime/pull/8968 - // TODO: Verify - // Wait for the shutdown signal to reach the client: assert!(client_rx.next().await.is_none()); - assert_eq!(client_fut.await, Some(Ok(Err(ErrorCode::ConnectionReset)))); + assert_eq!(client_fut.await, Some(Ok(Ok(())))); }) .await; } @@ -56,48 +43,37 @@ async fn test_tcp_input_stream_should_be_closed_by_local_shutdown(family: IpAddr server_tx.send(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await.unwrap(); drop(server_tx); }, - ); + ); let (client_rx, client_fut) = client.receive(); - // TODO: Verify - // Shut down socket locally: drop(client_rx); // Wait for the shutdown signal to reach the client: - assert_eq!(client_fut.await, Some(Ok(Err(ErrorCode::ConnectionReset)))); + assert_eq!(client_fut.await, Some(Ok(Ok(())))); }).await; } -/// OutputStream should return `StreamError::Closed` after the connection has been locally shut down for sending. +/// StreamWriter should return `StreamError::Closed` after the connection has been locally shut down for sending. async fn test_tcp_output_stream_should_be_closed_by_local_shutdown(family: IpAddressFamily) { - setup(family, |server, client| async move { - let (server_tx, server_rx) = wit_stream::new(); - drop(server_tx); - server.send(server_rx).await.unwrap(); - - let (server_tx, server_rx) = wit_stream::new(); - drop(server_tx); - assert_eq!( - server.send(server_rx).await, - Err(ErrorCode::ConnectionReset) - ); - - let (client_tx, client_rx) = wit_stream::new(); - drop(client_tx); - client.send(client_rx).await.unwrap(); - - let (client_tx, client_rx) = wit_stream::new(); - drop(client_tx); - assert_eq!( - client.send(client_rx).await, - Err(ErrorCode::ConnectionReset) + setup(family, |_server, client| async move { + let (mut client_tx, client_rx) = wit_stream::new(); + client_tx.close().await.unwrap(); + join!( + async { + client.send(client_rx).await.unwrap(); + }, + async { + // TODO: Verify if send on the stream should return an error + //assert!(client_tx.send(b"Hi!".into()).await.is_err()); + drop(client_tx); + } ); }) .await; } -/// Calling `shutdown` while the OutputStream is in the middle of a background write should not cause that write to be lost. +/// Calling `shutdown` while the StreamWriter is in the middle of a background write should not cause that write to be lost. async fn test_tcp_shutdown_should_not_lose_data(family: IpAddressFamily) { setup(family, |server, client| async move { // Minimize the local send buffer: @@ -117,20 +93,19 @@ async fn test_tcp_shutdown_should_not_lose_data(family: IpAddressFamily) { }, async { client_tx.send(outgoing_data.clone()).await.unwrap(); - client_tx.close().await.unwrap(); drop(client_tx); }, + async { + // The peer should receive _all_ data: + let (server_rx, server_fut) = server.receive(); + let incoming_data = server_rx.try_collect::>().await.unwrap().concat(); + assert_eq!( + outgoing_data, incoming_data, + "Received data should match the sent data" + ); + server_fut.await.unwrap().unwrap().unwrap() + }, ); - - // The peer should receive _all_ data: - let (server_rx, server_fut) = server.receive(); - let incoming_data = server_rx.try_collect::>().await.unwrap().concat(); - assert_eq!( - outgoing_data.len(), - incoming_data.len(), - "Received data should match the sent data" - ); - server_fut.await.unwrap().unwrap().unwrap() }) .await; } @@ -165,10 +140,15 @@ async fn setup>( let mut accept = listener.listen().unwrap(); let bound_address = listener.local_address().unwrap(); let client_socket = TcpSocket::new(family); - client_socket.connect(bound_address).await.unwrap(); - let mut accepted_socket = accept.next().await.unwrap().unwrap(); - assert_eq!(accepted_socket.len(), 1); - let accepted_socket = accepted_socket.pop().unwrap(); - + let ((), accepted_socket) = join!( + async { + client_socket.connect(bound_address).await.unwrap(); + }, + async { + let mut accepted_socket = accept.next().await.unwrap().unwrap(); + assert_eq!(accepted_socket.len(), 1); + accepted_socket.pop().unwrap() + }, + ); body(accepted_socket, client_socket).await; } diff --git a/crates/wasi/src/p3/sockets/host/types/tcp.rs b/crates/wasi/src/p3/sockets/host/types/tcp.rs index f9430399eb..1c0764b91b 100644 --- a/crates/wasi/src/p3/sockets/host/types/tcp.rs +++ b/crates/wasi/src/p3/sockets/host/types/tcp.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use anyhow::{ensure, Context as _}; use io_lifetimes::AsSocketlike as _; use rustix::io::Errno; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, Notify}; use wasmtime::component::{ future, stream, Accessor, BackgroundTask, FutureReader, FutureWriter, Lift, Resource, ResourceTable, StreamReader, StreamWriter, @@ -19,7 +19,7 @@ use wasmtime::component::{ use crate::p3::bindings::sockets::types::{ Duration, ErrorCode, HostTcpSocket, IpAddressFamily, IpSocketAddress, TcpSocket, }; -use crate::p3::sockets::tcp::{bind, TcpState}; +use crate::p3::sockets::tcp::{bind, handle_listener, TcpState}; use crate::p3::sockets::util::is_valid_unicast_address; use crate::p3::sockets::{SocketAddrUse, SocketAddressFamily, WasiSocketsImpl, WasiSocketsView}; use crate::runtime::spawn; @@ -169,10 +169,7 @@ impl BackgroundTask for ListenTask { ); } } - TcpState::Connected { - stream: Arc::new(stream), - abort_receive: None, - } + TcpState::connected(stream) } Err(err) => { match Errno::from_io_error(&err) { @@ -220,35 +217,29 @@ impl BackgroundTask for ListenTask { })?; tx = fut.into_future().await; } + store.with(|store| tx.close(store).context("failed to close stream"))?; Ok(()) } } struct ReceiveTask { - stream: Arc, data: StreamWriter, result: FutureWriter>, - abort: oneshot::Receiver<()>, + abort: Arc, + rx: mpsc::Receiver, ErrorCode>>, } impl BackgroundTask for ReceiveTask { - async fn run(self, store: &mut Accessor) -> wasmtime::Result<()> { - let mut abort = pin!(self.abort); + async fn run(mut self, store: &mut Accessor) -> wasmtime::Result<()> { let mut tx = self.data; + let mut abort = pin!(self.abort.notified()); let res = loop { - let mut buf = vec![0; 8096]; - match self.stream.try_read(&mut buf) { - Ok(0) => { - if let Err(err) = self - .stream - .as_socketlike_view::() - .shutdown(Shutdown::Read) - { - break Err(err.into()); - } + match self.rx.recv().await { + None => { + store.with(|store| tx.close(store).context("failed to close stream"))?; + break Ok(()); } - Ok(n) => { - buf.truncate(n); + Some(Ok(buf)) => { let fut = store.with(|store| tx.write(store, buf).context("failed to send chunk"))?; let mut fut = fut.into_future(); @@ -265,23 +256,10 @@ impl BackgroundTask for ReceiveTask { } }; } - Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { - let mut writable = pin!(self.stream.writable()); - match poll_fn(|cx| match abort.as_mut().poll(cx) { - Poll::Ready(..) => Poll::Ready(None), - Poll::Pending => writable.as_mut().poll(cx).map(Some), - }) - .await - { - Some(Ok(())) => {} - Some(Err(err)) => break Err(err.into()), - None => { - // socket dropped, abort - break Ok(()); - } - } + Some(Err(err)) => { + store.with(|store| tx.close(store).context("failed to close stream"))?; + break Err(err.into()); } - Err(err) => break Err(err.into()), } }; let fut = store.with(|store| { @@ -307,11 +285,9 @@ where fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { let socket = TcpSocket::new(address_family.into()).context("failed to create socket")?; - let socket = self - .table() + self.table() .push(socket) - .context("failed to push socket resource to table")?; - Ok(socket) + .context("failed to push socket resource to table") } async fn bind( @@ -387,10 +363,7 @@ where ); match res { Ok(stream) => { - socket.tcp_state = TcpState::Connected { - stream: Arc::new(stream), - abort_receive: None, - }; + socket.tcp_state = TcpState::connected(stream); Ok(Ok(())) } Err(err) => { @@ -453,34 +426,7 @@ where listener: Arc::clone(&listener), finished: finished_rx, abort: abort_tx, - task: spawn(async move { - let mut abort = pin!(abort_rx); - loop { - let accept = listener.accept(); - let mut accept = pin!(accept); - let Some(res) = poll_fn(|cx| match abort.as_mut().poll(cx) { - Poll::Ready(..) => Poll::Ready(None), - Poll::Pending => accept.as_mut().poll(cx).map(Some), - }) - .await - else { - break; - }; - let send = task_tx.send(res); - let mut send = pin!(send); - match poll_fn(|cx| match abort.as_mut().poll(cx) { - Poll::Ready(..) => Poll::Ready(None), - Poll::Pending => send.as_mut().poll(cx).map(Some), - }) - .await - { - None | Some(Err(..)) => break, - Some(Ok(())) => {} - } - } - drop(listener); - _ = finished_tx.send(()); - }), + task: spawn(handle_listener(listener, abort_rx, finished_tx, task_tx)), }; Ok(Ok(( rx, @@ -548,13 +494,10 @@ where let mut fut = fut.into_future(); 'outer: loop { let Some((tail, mut buf)) = fut.await else { - match stream + _ = stream .as_socketlike_view::() - .shutdown(Shutdown::Write) - { - Ok(()) => return Ok(Ok(())), - Err(err) => return Ok(Err(err.into())), - } + .shutdown(Shutdown::Write); + return Ok(Ok(())); }; let mut buf = buf.as_mut_slice(); loop { @@ -569,10 +512,18 @@ where } Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { if let Err(err) = stream.writable().await { + _ = stream + .as_socketlike_view::() + .shutdown(Shutdown::Write); return Ok(Err(err.into())); } } - Err(err) => return Ok(Err(err.into())), + Err(err) => { + _ = stream + .as_socketlike_view::() + .shutdown(Shutdown::Write); + return Ok(Err(err.into())); + } } } } @@ -586,24 +537,19 @@ where let (data_tx, data_rx) = stream(&mut store).context("failed to create stream")?; let (res_tx, res_rx) = future(&mut store).context("failed to create future")?; let sock = get_socket_mut(store.data_mut().table(), &socket)?; - let (abort_tx, abort_rx) = oneshot::channel(); - if let TcpState::Connected { - stream, - abort_receive, - } = &mut sock.tcp_state - { - ensure!( - abort_receive.replace(abort_tx).is_none(), - "`receive` can called at most once" - ); - let stream = Arc::clone(&stream); + if let TcpState::Connected { rx, abort, .. } = &mut sock.tcp_state { + let rx = rx.take().context("`receive` can be called at most once")?; + let abort = Arc::clone(&abort); store.spawn(ReceiveTask { - stream, data: data_tx, result: res_tx, - abort: abort_rx, + abort, + rx, }); } else { + data_tx + .close(&mut store) + .context("failed to close stream")?; let fut = res_tx .write(&mut store, Err(ErrorCode::InvalidState)) .context("failed to write result to future")?; @@ -796,6 +742,19 @@ where ); Ok(()) } + TcpState::Connected { + abort, + finished, + stream, + .. + } => { + abort.notify_waiters(); + // this will unblock only once the task finishes + _ = finished.recv(); + // this must be the only reference to the stream left + ensure!(Arc::into_inner(stream).is_some(), "corrupted stream state"); + Ok(()) + } _ => Ok(()), } } diff --git a/crates/wasi/src/p3/sockets/tcp.rs b/crates/wasi/src/p3/sockets/tcp.rs index 9b2cca2c82..e1287c4238 100644 --- a/crates/wasi/src/p3/sockets/tcp.rs +++ b/crates/wasi/src/p3/sockets/tcp.rs @@ -1,23 +1,110 @@ use core::fmt::Debug; +use core::future::{poll_fn, Future as _}; use core::net::SocketAddr; +use core::pin::pin; +use core::task::Poll; +use std::net::Shutdown; use std::os::fd::{AsFd as _, BorrowedFd}; use std::sync::Arc; use cap_net_ext::AddressFamily; +use io_lifetimes::AsSocketlike as _; use rustix::io::Errno; use rustix::net::sockopt; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot, Notify}; use crate::p3::bindings::sockets::types::{Duration, ErrorCode, IpAddressFamily, IpSocketAddress}; use crate::p3::sockets::SocketAddressFamily; -use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle}; +use crate::runtime::{spawn, with_ambient_tokio_runtime, AbortOnDropJoinHandle}; use super::util::{normalize_get_buffer_size, normalize_set_buffer_size}; /// Value taken from rust std library. const DEFAULT_BACKLOG: u32 = 128; +pub async fn handle_stream( + stream: Arc, + abort: Arc, + finished: std::sync::mpsc::Sender<()>, + tx: mpsc::Sender, ErrorCode>>, +) { + let mut abort = pin!(abort.notified()); + loop { + let tx = tx.reserve(); + let mut tx = pin!(tx); + let Some(Ok(tx)) = poll_fn(|cx| match abort.as_mut().poll(cx) { + Poll::Ready(..) => Poll::Ready(None), + Poll::Pending => tx.as_mut().poll(cx).map(Some), + }) + .await + else { + break; + }; + let mut buf = vec![0; 8096]; + match stream.try_read(&mut buf) { + Ok(0) => break, + Ok(n) => { + buf.truncate(n); + tx.send(Ok(buf)); + } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + let mut readable = pin!(stream.readable()); + match poll_fn(|cx| match abort.as_mut().poll(cx) { + Poll::Ready(..) => Poll::Ready(None), + Poll::Pending => readable.as_mut().poll(cx).map(Some), + }) + .await + { + Some(Ok(())) => {} + Some(Err(err)) => tx.send(Err(err.into())), + None => break, + } + } + Err(err) => tx.send(Err(err.into())), + } + } + _ = stream + .as_socketlike_view::() + .shutdown(Shutdown::Read); + drop(stream); + _ = finished.send(()); +} + +pub async fn handle_listener( + listener: Arc, + abort: oneshot::Receiver<()>, + finished: std::sync::mpsc::Sender<()>, + tx: mpsc::Sender>, +) { + let mut abort = pin!(abort); + loop { + let tx = tx.reserve(); + let mut tx = pin!(tx); + let Some(Ok(tx)) = poll_fn(|cx| match abort.as_mut().poll(cx) { + Poll::Ready(..) => Poll::Ready(None), + Poll::Pending => tx.as_mut().poll(cx).map(Some), + }) + .await + else { + break; + }; + let accept = listener.accept(); + let mut accept = pin!(accept); + let Some(res) = poll_fn(|cx| match abort.as_mut().poll(cx) { + Poll::Ready(..) => Poll::Ready(None), + Poll::Pending => accept.as_mut().poll(cx).map(Some), + }) + .await + else { + break; + }; + tx.send(res); + } + drop(listener); + _ = finished.send(()); +} + /// The state of a TCP socket. /// /// This represents the various states a socket can be in during the @@ -32,9 +119,9 @@ pub enum TcpState { /// The socket is now listening and waiting for an incoming connection. Listening { listener: Arc, - task: AbortOnDropJoinHandle<()>, finished: std::sync::mpsc::Receiver<()>, abort: oneshot::Sender<()>, + task: AbortOnDropJoinHandle<()>, }, /// An outgoing connection is started. @@ -43,7 +130,10 @@ pub enum TcpState { /// A connection has been established. Connected { stream: Arc, - abort_receive: Option>, + finished: std::sync::mpsc::Receiver<()>, + abort: Arc, + rx: Option, ErrorCode>>>, + task: AbortOnDropJoinHandle<()>, }, Error(ErrorCode), @@ -65,6 +155,22 @@ impl Debug for TcpState { } } +impl TcpState { + pub fn connected(stream: tokio::net::TcpStream) -> Self { + let stream = Arc::new(stream); + let (task_tx, task_rx) = mpsc::channel(1); + let (finished_tx, finished_rx) = std::sync::mpsc::channel(); + let abort = Arc::default(); + Self::Connected { + stream: Arc::clone(&stream), + finished: finished_rx, + abort: Arc::clone(&abort), + rx: Some(task_rx), + task: spawn(handle_stream(stream, abort, finished_tx, task_tx)), + } + } +} + /// A host TCP socket, plus associated bookkeeping. pub struct TcpSocket { /// The current state in the bind/listen/accept/connect progression. diff --git a/crates/wasi/tests/all/p3/sockets.rs b/crates/wasi/tests/all/p3/sockets.rs index 635c1909b9..aa29f4c258 100644 --- a/crates/wasi/tests/all/p3/sockets.rs +++ b/crates/wasi/tests/all/p3/sockets.rs @@ -18,7 +18,6 @@ async fn sockets_0_3_tcp_connect() -> anyhow::Result<()> { run(SOCKETS_0_3_TCP_CONNECT_COMPONENT).await } -#[ignore = "trap"] #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn sockets_0_3_tcp_sample_application() -> anyhow::Result<()> { run(SOCKETS_0_3_TCP_SAMPLE_APPLICATION_COMPONENT).await @@ -34,7 +33,6 @@ async fn sockets_0_3_tcp_states() -> anyhow::Result<()> { run(SOCKETS_0_3_TCP_STATES_COMPONENT).await } -#[ignore = "deadlock"] #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn sockets_0_3_tcp_streams() -> anyhow::Result<()> { run(SOCKETS_0_3_TCP_STREAMS_COMPONENT).await