-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Expose Wasi-TLS handshake error #10429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,66 +1,96 @@ | ||
| use anyhow::{Context, Result}; | ||
| use anyhow::{anyhow, Context, Result}; | ||
| use core::str; | ||
| use test_programs::wasi::sockets::network::{IpSocketAddress, Network}; | ||
| use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network}; | ||
| use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket}; | ||
| use test_programs::wasi::tls::types::ClientHandshake; | ||
|
|
||
| fn make_tls_request(domain: &str) -> Result<String> { | ||
| const PORT: u16 = 443; | ||
| const PORT: u16 = 443; | ||
|
|
||
| fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { | ||
| let request = | ||
| format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n"); | ||
|
|
||
| let net = Network::default(); | ||
|
|
||
| let Some(ip) = net | ||
| .permissive_blocking_resolve_addresses(domain) | ||
| .unwrap() | ||
| .first() | ||
| .map(|a| a.to_owned()) | ||
| else { | ||
| return Err(anyhow::anyhow!("DNS lookup failed.")); | ||
| }; | ||
|
|
||
| let socket = TcpSocket::new(ip.family()).unwrap(); | ||
| let (tcp_input, tcp_output) = socket | ||
| .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) | ||
| .context("failed to connect")?; | ||
| .context("tcp connect failed")?; | ||
|
|
||
| let (client_connection, tls_input, tls_output) = | ||
| ClientHandshake::new(domain, tcp_input, tcp_output) | ||
| .blocking_finish() | ||
| .map_err(|_| anyhow::anyhow!("failed to finish handshake"))?; | ||
| .context("tls handshake failed")?; | ||
|
|
||
| tls_output.blocking_write_util(request.as_bytes()).unwrap(); | ||
| tls_output | ||
| .blocking_write_util(request.as_bytes()) | ||
| .context("writing http request failed")?; | ||
| client_connection | ||
| .blocking_close_output(&tls_output) | ||
| .map_err(|_| anyhow::anyhow!("failed to close tls connection"))?; | ||
| .context("closing tls connection failed")?; | ||
| socket.shutdown(ShutdownType::Send)?; | ||
| let response = tls_input | ||
| .blocking_read_to_end() | ||
| .map_err(|_| anyhow::anyhow!("failed to read output"))?; | ||
| String::from_utf8(response).context("error converting response") | ||
| .context("reading http response failed")?; | ||
|
|
||
| if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") { | ||
| Ok(()) | ||
| } else { | ||
| Err(anyhow!("server did not respond with 200 OK")) | ||
| } | ||
| } | ||
|
|
||
| fn test_tls_sample_application() { | ||
| // since this is testing remote endpoint to ensure system cert store works | ||
| /// This test sets up a TCP connection using one domain, and then attempts to | ||
| /// perform a TLS handshake using another unrelated domain. This should result | ||
| /// in a handshake error. | ||
| fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { | ||
| const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; | ||
|
|
||
| let net = Network::default(); | ||
|
|
||
| let socket = TcpSocket::new(ip.family()).unwrap(); | ||
| let (tcp_input, tcp_output) = socket | ||
| .blocking_connect(&net, IpSocketAddress::new(ip, PORT)) | ||
| .context("tcp connect failed")?; | ||
|
|
||
| match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() { | ||
| // We're expecting an error regarding the "certificate" is some form or | ||
| // another. When we add more TLS backends other than rustls, this naive | ||
| // check will likely need to be revisited/expanded: | ||
| Err(e) if e.to_debug_string().contains("certificate") => Ok(()), | ||
|
|
||
| Err(e) => Err(e.into()), | ||
| Ok(_) => panic!("expecting server name mismatch"), | ||
| } | ||
| } | ||
|
|
||
| fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) { | ||
| // since this is testing remote endpoints to ensure system cert store works | ||
| // the test uses a couple different endpoints to reduce the number of flakes | ||
| const DOMAINS: &'static [&'static str] = &["example.com", "api.github.com"]; | ||
|
|
||
| let net = Network::default(); | ||
|
|
||
| for &domain in DOMAINS { | ||
| match make_tls_request(domain) { | ||
| Ok(r) => { | ||
| assert!(r.contains("HTTP/1.1 200 OK")); | ||
| return; | ||
| } | ||
| let lookup = net | ||
| .permissive_blocking_resolve_addresses(domain) | ||
| .unwrap() | ||
| .first() | ||
| .map(|a| a.to_owned()) | ||
| .ok_or_else(|| anyhow!("DNS lookup failed.")); | ||
|
|
||
| match lookup.and_then(|ip| test(&domain, ip)) { | ||
| Ok(()) => return, | ||
| Err(e) => { | ||
| eprintln!("Failed to make TLS request to {domain}: {e}"); | ||
| eprintln!("test for {domain} failed: {e:#}"); | ||
| } | ||
| } | ||
| } | ||
| panic!("All TLS requests failed."); | ||
|
|
||
| panic!("all tests failed"); | ||
| } | ||
|
|
||
| fn main() { | ||
| test_tls_sample_application(); | ||
| try_live_endpoints(test_tls_sample_application); | ||
| try_live_endpoints(test_tls_invalid_certificate); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ | |
| #![doc(test(attr(deny(warnings))))] | ||
| #![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))] | ||
|
|
||
| use anyhow::{Context, Result}; | ||
| use anyhow::Result; | ||
| use bytes::Bytes; | ||
| use rustls::pki_types::ServerName; | ||
| use std::io; | ||
|
|
@@ -88,6 +88,7 @@ use wasmtime_wasi::OutputStream; | |
| use wasmtime_wasi::{ | ||
| async_trait, | ||
| bindings::io::{ | ||
| error::Error as HostIoError, | ||
| poll::Pollable as HostPollable, | ||
| streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream}, | ||
| }, | ||
|
|
@@ -149,6 +150,57 @@ pub fn add_to_linker<T: Send>( | |
| generated::types::add_to_linker_get_host(l, &opts, f)?; | ||
| Ok(()) | ||
| } | ||
|
|
||
| enum TlsError { | ||
| /// The component should trap. Under normal circumstances, this only occurs | ||
| /// when the underlying transport stream returns [`StreamError::Trap`]. | ||
| Trap(anyhow::Error), | ||
|
|
||
| /// A failure indicated by the underlying transport stream as | ||
| /// [`StreamError::LastOperationFailed`]. | ||
| Io(wasmtime_wasi::IoError), | ||
|
|
||
| /// A TLS protocol error occurred. | ||
| Tls(rustls::Error), | ||
| } | ||
|
|
||
| impl TlsError { | ||
| /// Create a [`TlsError::Tls`] error from a simple message. | ||
| fn msg(msg: &str) -> Self { | ||
| // (Ab)using rustls' error type to synthesize our own TLS errors: | ||
| Self::Tls(rustls::Error::General(msg.to_string())) | ||
| } | ||
| } | ||
|
|
||
| impl From<io::Error> for TlsError { | ||
| fn from(error: io::Error) -> Self { | ||
| // Report unexpected EOFs as an error to prevent truncation attacks. | ||
| // See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read | ||
| if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() { | ||
| return Self::msg("underlying transport closed abruptly"); | ||
| } | ||
|
|
||
| // Errors from underlying transport. | ||
| // These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below. | ||
| let error = match error.downcast::<StreamError>() { | ||
| Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e), | ||
| Ok(StreamError::Trap(e)) => return Self::Trap(e), | ||
| Ok(StreamError::Closed) => unreachable!("our wasi-to-tokio stream transformer should have translated this to a 0-sized read"), | ||
| Err(e) => e, | ||
| }; | ||
|
|
||
| // Errors from `rustls`. | ||
| // These have been wrapped inside `io::Error`s by `tokio-rustls`. | ||
| let error = match error.downcast::<rustls::Error>() { | ||
| Ok(e) => return Self::Tls(e), | ||
| Err(e) => e, | ||
| }; | ||
|
|
||
| // All errors should have been handled by the clauses above. | ||
| Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error")) | ||
| } | ||
| } | ||
|
|
||
| /// Represents the ClientHandshake which will be used to configure the handshake | ||
| pub struct ClientHandShake { | ||
| server_name: String, | ||
|
|
@@ -180,16 +232,17 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { | |
| let handshake = self.table.delete(this)?; | ||
| let server_name = handshake.server_name; | ||
| let streams = handshake.streams; | ||
| let domain = ServerName::try_from(server_name)?; | ||
|
|
||
| Ok(self | ||
| .table | ||
| .push(FutureStreams(StreamState::Pending(Box::pin(async move { | ||
| let connector = tokio_rustls::TlsConnector::from(default_client_config()); | ||
| connector | ||
| let domain = ServerName::try_from(server_name) | ||
| .map_err(|_| TlsError::msg("invalid server name"))?; | ||
|
|
||
| let stream = tokio_rustls::TlsConnector::from(default_client_config()) | ||
| .connect(domain, streams) | ||
| .await | ||
| .with_context(|| "connection failed") | ||
| .await?; | ||
| Ok(stream) | ||
| }))))?) | ||
| } | ||
|
|
||
|
|
@@ -203,7 +256,7 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> { | |
| } | ||
|
|
||
| /// Future streams provides the tls streams after the handshake is completed | ||
| pub struct FutureStreams<T>(StreamState<Result<T>>); | ||
| pub struct FutureStreams<T>(StreamState<Result<T, TlsError>>); | ||
|
|
||
| /// Library specific version of TLS connection after the handshake is completed. | ||
| /// This alias allows it to use with wit-bindgen component generator which won't take generic types | ||
|
|
@@ -239,30 +292,36 @@ impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> { | |
| Resource<BoxInputStream>, | ||
| Resource<BoxOutputStream>, | ||
| ), | ||
| (), | ||
| Resource<HostIoError>, | ||
| >, | ||
| (), | ||
| >, | ||
| >, | ||
| > { | ||
| { | ||
| let this = self.table.get(&this)?; | ||
| match &this.0 { | ||
| StreamState::Pending(_) => return Ok(None), | ||
| StreamState::Ready(Ok(_)) => (), | ||
| StreamState::Ready(Err(_)) => { | ||
| return Ok(Some(Ok(Err(())))); | ||
| } | ||
| StreamState::Closed => return Ok(Some(Err(()))), | ||
| } | ||
| let this = &mut self.table.get_mut(&this)?.0; | ||
| match this { | ||
| StreamState::Pending(_) => return Ok(None), | ||
| StreamState::Closed => return Ok(Some(Err(()))), | ||
| StreamState::Ready(_) => (), | ||
| } | ||
|
|
||
| let StreamState::Ready(Ok(tls_stream)) = | ||
| mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed) | ||
| else { | ||
| let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else { | ||
| unreachable!() | ||
| }; | ||
|
|
||
| let tls_stream = match result { | ||
| Ok(s) => s, | ||
| Err(TlsError::Trap(e)) => return Err(e), | ||
| Err(TlsError::Io(e)) => { | ||
| let error = self.table.push(e)?; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need logic to clean up these table entries?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean the table entries for the newly created
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for the explanation |
||
| return Ok(Some(Ok(Err(error)))); | ||
| } | ||
| Err(TlsError::Tls(e)) => { | ||
| let error = self.table.push(wasmtime_wasi::IoError::new(e))?; | ||
| return Ok(Some(Ok(Err(error)))); | ||
| } | ||
| }; | ||
|
|
||
| let (rx, tx) = tokio::io::split(tls_stream); | ||
| let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx)); | ||
| let client = ClientConnection { | ||
|
|
@@ -347,15 +406,15 @@ impl AsyncWrite for WasiStreams { | |
| return match output.write(Bytes::copy_from_slice(&buf[..count])) { | ||
| Ok(()) => Poll::Ready(Ok(count)), | ||
| Err(StreamError::Closed) => Poll::Ready(Ok(0)), | ||
| Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { | ||
| Poll::Ready(Err(std::io::Error::other(e))) | ||
| } | ||
| Err(e) => Poll::Ready(Err(std::io::Error::other(e))), | ||
| }; | ||
| } | ||
| Err(StreamError::Closed) => return Poll::Ready(Ok(0)), | ||
| Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => { | ||
| return Poll::Ready(Err(std::io::Error::other(e))) | ||
| Err(StreamError::Closed) => { | ||
| // Our current version of tokio-rustls does not handle returning `Ok(0)` well. | ||
| // See: https://github.com/rustls/tokio-rustls/issues/92 | ||
| return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); | ||
| } | ||
| Err(e) => return Poll::Ready(Err(std::io::Error::other(e))), | ||
| }; | ||
| } | ||
| } | ||
|
|
@@ -621,7 +680,8 @@ mod tests { | |
| let (tx1, rx1) = oneshot::channel::<()>(); | ||
|
|
||
| let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move { | ||
| rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled")) | ||
| rx1.await | ||
| .map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled"))) | ||
| }))); | ||
|
|
||
| let mut fut = future_streams.ready(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we be return some information about the stream being closed with this error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See WebAssembly/wasi-tls#10