diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index e33a98c8bd57..33c14b8730f4 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -61,7 +61,7 @@ make_vendor "wasi-http" " make_vendor "wasi-tls" " io@v0.2.3 - tls@v0.2.0-draft+73b0a0f + tls@v0.2.0-draft+d6fbdc7 " make_vendor "wasi-config" "config@f4d699b" diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/tls_sample_application.rs index d9e21a88a196..4c7f306a43b6 100644 --- a/crates/test-programs/src/bin/tls_sample_application.rs +++ b/crates/test-programs/src/bin/tls_sample_application.rs @@ -1,66 +1,96 @@ -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use core::str; -use test_programs::wasi::sockets::network::{IpSocketAddress, Network}; +use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network}; use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket}; use test_programs::wasi::tls::types::ClientHandshake; -fn make_tls_request(domain: &str) -> Result { - const PORT: u16 = 443; +const PORT: u16 = 443; +fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { let request = format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n"); let net = Network::default(); - let Some(ip) = net - .permissive_blocking_resolve_addresses(domain) - .unwrap() - .first() - .map(|a| a.to_owned()) - else { - return Err(anyhow::anyhow!("DNS lookup failed.")); - }; - let socket = TcpSocket::new(ip.family()).unwrap(); let (tcp_input, tcp_output) = socket .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) - .context("failed to connect")?; + .context("tcp connect failed")?; let (client_connection, tls_input, tls_output) = ClientHandshake::new(domain, tcp_input, tcp_output) .blocking_finish() - .map_err(|_| anyhow::anyhow!("failed to finish handshake"))?; + .context("tls handshake failed")?; - tls_output.blocking_write_util(request.as_bytes()).unwrap(); + tls_output + .blocking_write_util(request.as_bytes()) + .context("writing http request failed")?; client_connection .blocking_close_output(&tls_output) - .map_err(|_| anyhow::anyhow!("failed to close tls connection"))?; + .context("closing tls connection failed")?; socket.shutdown(ShutdownType::Send)?; let response = tls_input .blocking_read_to_end() - .map_err(|_| anyhow::anyhow!("failed to read output"))?; - String::from_utf8(response).context("error converting response") + .context("reading http response failed")?; + + if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") { + Ok(()) + } else { + Err(anyhow!("server did not respond with 200 OK")) + } } -fn test_tls_sample_application() { - // since this is testing remote endpoint to ensure system cert store works +/// This test sets up a TCP connection using one domain, and then attempts to +/// perform a TLS handshake using another unrelated domain. This should result +/// in a handshake error. +fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { + const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; + + let net = Network::default(); + + let socket = TcpSocket::new(ip.family()).unwrap(); + let (tcp_input, tcp_output) = socket + .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) + .context("tcp connect failed")?; + + match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() { + // We're expecting an error regarding the "certificate" is some form or + // another. When we add more TLS backends other than rustls, this naive + // check will likely need to be revisited/expanded: + Err(e) if e.to_debug_string().contains("certificate") => Ok(()), + + Err(e) => Err(e.into()), + Ok(_) => panic!("expecting server name mismatch"), + } +} + +fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) { + // since this is testing remote endpoints to ensure system cert store works // the test uses a couple different endpoints to reduce the number of flakes const DOMAINS: &'static [&'static str] = &["example.com", "api.github.com"]; + let net = Network::default(); + for &domain in DOMAINS { - match make_tls_request(domain) { - Ok(r) => { - assert!(r.contains("HTTP/1.1 200 OK")); - return; - } + let lookup = net + .permissive_blocking_resolve_addresses(domain) + .unwrap() + .first() + .map(|a| a.to_owned()) + .ok_or_else(|| anyhow!("DNS lookup failed.")); + + match lookup.and_then(|ip| test(&domain, ip)) { + Ok(()) => return, Err(e) => { - eprintln!("Failed to make TLS request to {domain}: {e}"); + eprintln!("test for {domain} failed: {e:#}"); } } } - panic!("All TLS requests failed."); + + panic!("all tests failed"); } fn main() { - test_tls_sample_application(); + try_live_endpoints(test_tls_sample_application); + try_live_endpoints(test_tls_invalid_certificate); } diff --git a/crates/test-programs/src/lib.rs b/crates/test-programs/src/lib.rs index d22cebac88a8..562577f0a4db 100644 --- a/crates/test-programs/src/lib.rs +++ b/crates/test-programs/src/lib.rs @@ -48,3 +48,11 @@ pub mod proxy { }, }); } + +impl std::fmt::Display for wasi::io::error::Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.to_debug_string()) + } +} + +impl std::error::Error for wasi::io::error::Error {} diff --git a/crates/test-programs/src/tls.rs b/crates/test-programs/src/tls.rs index b59cb29948eb..68076aa494df 100644 --- a/crates/test-programs/src/tls.rs +++ b/crates/test-programs/src/tls.rs @@ -1,11 +1,12 @@ use crate::wasi::clocks::monotonic_clock; +use crate::wasi::io::error::Error as IoError; use crate::wasi::io::streams::StreamError; use crate::wasi::tls::types::{ClientConnection, ClientHandshake, InputStream, OutputStream}; const TIMEOUT_NS: u64 = 1_000_000_000; impl ClientHandshake { - pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), ()> { + pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), IoError> { let future = ClientHandshake::finish(self); let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS * 200); let pollable = future.subscribe(); diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 8ca591b1636c..583d37c2e8e4 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -71,7 +71,7 @@ #![doc(test(attr(deny(warnings))))] #![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] -use anyhow::{Context, Result}; +use anyhow::Result; use bytes::Bytes; use rustls::pki_types::ServerName; use std::io; @@ -88,6 +88,7 @@ use wasmtime_wasi::OutputStream; use wasmtime_wasi::{ async_trait, bindings::io::{ + error::Error as HostIoError, poll::Pollable as HostPollable, streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream}, }, @@ -149,6 +150,57 @@ pub fn add_to_linker( generated::types::add_to_linker_get_host(l, &opts, f)?; Ok(()) } + +enum TlsError { + /// The component should trap. Under normal circumstances, this only occurs + /// when the underlying transport stream returns [`StreamError::Trap`]. + Trap(anyhow::Error), + + /// A failure indicated by the underlying transport stream as + /// [`StreamError::LastOperationFailed`]. + Io(wasmtime_wasi::IoError), + + /// A TLS protocol error occurred. + Tls(rustls::Error), +} + +impl TlsError { + /// Create a [`TlsError::Tls`] error from a simple message. + fn msg(msg: &str) -> Self { + // (Ab)using rustls' error type to synthesize our own TLS errors: + Self::Tls(rustls::Error::General(msg.to_string())) + } +} + +impl From for TlsError { + fn from(error: io::Error) -> Self { + // Report unexpected EOFs as an error to prevent truncation attacks. + // See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read + if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() { + return Self::msg("underlying transport closed abruptly"); + } + + // Errors from underlying transport. + // These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below. + let error = match error.downcast::() { + Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e), + Ok(StreamError::Trap(e)) => return Self::Trap(e), + Ok(StreamError::Closed) => unreachable!("our wasi-to-tokio stream transformer should have translated this to a 0-sized read"), + Err(e) => e, + }; + + // Errors from `rustls`. + // These have been wrapped inside `io::Error`s by `tokio-rustls`. + let error = match error.downcast::() { + Ok(e) => return Self::Tls(e), + Err(e) => e, + }; + + // All errors should have been handled by the clauses above. + Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error")) + } +} + /// Represents the ClientHandshake which will be used to configure the handshake pub struct ClientHandShake { server_name: String, @@ -180,16 +232,17 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { let handshake = self.table.delete(this)?; let server_name = handshake.server_name; let streams = handshake.streams; - let domain = ServerName::try_from(server_name)?; Ok(self .table .push(FutureStreams(StreamState::Pending(Box::pin(async move { - let connector = tokio_rustls::TlsConnector::from(default_client_config()); - connector + let domain = ServerName::try_from(server_name) + .map_err(|_| TlsError::msg("invalid server name"))?; + + let stream = tokio_rustls::TlsConnector::from(default_client_config()) .connect(domain, streams) - .await - .with_context(|| "connection failed") + .await?; + Ok(stream) }))))?) } @@ -203,7 +256,7 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { } /// Future streams provides the tls streams after the handshake is completed -pub struct FutureStreams(StreamState>); +pub struct FutureStreams(StreamState>); /// Library specific version of TLS connection after the handshake is completed. /// This alias allows it to use with wit-bindgen component generator which won't take generic types @@ -239,30 +292,36 @@ impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> { Resource, Resource, ), - (), + Resource, >, (), >, >, > { - { - let this = self.table.get(&this)?; - match &this.0 { - StreamState::Pending(_) => return Ok(None), - StreamState::Ready(Ok(_)) => (), - StreamState::Ready(Err(_)) => { - return Ok(Some(Ok(Err(())))); - } - StreamState::Closed => return Ok(Some(Err(()))), - } + let this = &mut self.table.get_mut(&this)?.0; + match this { + StreamState::Pending(_) => return Ok(None), + StreamState::Closed => return Ok(Some(Err(()))), + StreamState::Ready(_) => (), } - let StreamState::Ready(Ok(tls_stream)) = - mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed) - else { + let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else { unreachable!() }; + let tls_stream = match result { + Ok(s) => s, + Err(TlsError::Trap(e)) => return Err(e), + Err(TlsError::Io(e)) => { + let error = self.table.push(e)?; + return Ok(Some(Ok(Err(error)))); + } + Err(TlsError::Tls(e)) => { + let error = self.table.push(wasmtime_wasi::IoError::new(e))?; + return Ok(Some(Ok(Err(error)))); + } + }; + let (rx, tx) = tokio::io::split(tls_stream); let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx)); let client = ClientConnection { @@ -347,15 +406,15 @@ impl AsyncWrite for WasiStreams { return match output.write(Bytes::copy_from_slice(&buf[..count])) { Ok(()) => Poll::Ready(Ok(count)), Err(StreamError::Closed) => Poll::Ready(Ok(0)), - Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { - Poll::Ready(Err(std::io::Error::other(e))) - } + Err(e) => Poll::Ready(Err(std::io::Error::other(e))), }; } - Err(StreamError::Closed) => return Poll::Ready(Ok(0)), - Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { - return Poll::Ready(Err(std::io::Error::other(e))) + Err(StreamError::Closed) => { + // Our current version of tokio-rustls does not handle returning `Ok(0)` well. + // See: https://github.com/rustls/tokio-rustls/issues/92 + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); } + Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), }; } } @@ -621,7 +680,8 @@ mod tests { let (tx1, rx1) = oneshot::channel::<()>(); let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move { - rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled")) + rx1.await + .map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled"))) }))); let mut fut = future_streams.ready(); diff --git a/crates/wasi-tls/wit/deps/tls/types.wit b/crates/wasi-tls/wit/deps/tls/types.wit index ee8a847a2e7b..6c8a48734767 100644 --- a/crates/wasi-tls/wit/deps/tls/types.wit +++ b/crates/wasi-tls/wit/deps/tls/types.wit @@ -4,6 +4,8 @@ interface types { use wasi:io/streams@0.2.3.{input-stream, output-stream}; @unstable(feature = tls) use wasi:io/poll@0.2.3.{pollable}; + @unstable(feature = tls) + use wasi:io/error@0.2.3.{error as io-error}; @unstable(feature = tls) resource client-handshake { @@ -26,6 +28,6 @@ interface types { subscribe: func() -> pollable; @unstable(feature = tls) - get: func() -> option>>>; + get: func() -> option, io-error>>>; } } diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index 71eac772645c..6c77a27b90f5 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -280,7 +280,8 @@ pub use wasmtime::component::{ResourceTable, ResourceTableError}; // users of this crate depend on them at these names. pub use wasmtime_wasi_io::poll::{subscribe, DynFuture, DynPollable, MakeFuture, Pollable}; pub use wasmtime_wasi_io::streams::{ - DynInputStream, DynOutputStream, InputStream, OutputStream, StreamError, StreamResult, + DynInputStream, DynOutputStream, Error as IoError, InputStream, OutputStream, StreamError, + StreamResult, }; pub use wasmtime_wasi_io::{IoImpl, IoView};