Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/vendor-wit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ make_vendor "wasi-http" "

make_vendor "wasi-tls" "
io@v0.2.3
tls@v0.2.0-draft+73b0a0f
tls@v0.2.0-draft+d6fbdc7
"

make_vendor "wasi-config" "config@f4d699b"
Expand Down
88 changes: 59 additions & 29 deletions crates/test-programs/src/bin/tls_sample_application.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,96 @@
use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use core::str;
use test_programs::wasi::sockets::network::{IpSocketAddress, Network};
use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network};
use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
use test_programs::wasi::tls::types::ClientHandshake;

fn make_tls_request(domain: &str) -> Result<String> {
const PORT: u16 = 443;
const PORT: u16 = 443;

fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
let request =
format!("GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\n\r\n");

let net = Network::default();

let Some(ip) = net
.permissive_blocking_resolve_addresses(domain)
.unwrap()
.first()
.map(|a| a.to_owned())
else {
return Err(anyhow::anyhow!("DNS lookup failed."));
};

let socket = TcpSocket::new(ip.family()).unwrap();
let (tcp_input, tcp_output) = socket
.blocking_connect(&net, IpSocketAddress::new(ip, PORT))
.context("failed to connect")?;
.context("tcp connect failed")?;

let (client_connection, tls_input, tls_output) =
ClientHandshake::new(domain, tcp_input, tcp_output)
.blocking_finish()
.map_err(|_| anyhow::anyhow!("failed to finish handshake"))?;
.context("tls handshake failed")?;

tls_output.blocking_write_util(request.as_bytes()).unwrap();
tls_output
.blocking_write_util(request.as_bytes())
.context("writing http request failed")?;
client_connection
.blocking_close_output(&tls_output)
.map_err(|_| anyhow::anyhow!("failed to close tls connection"))?;
.context("closing tls connection failed")?;
socket.shutdown(ShutdownType::Send)?;
let response = tls_input
.blocking_read_to_end()
.map_err(|_| anyhow::anyhow!("failed to read output"))?;
String::from_utf8(response).context("error converting response")
.context("reading http response failed")?;

if String::from_utf8(response)?.contains("HTTP/1.1 200 OK") {
Ok(())
} else {
Err(anyhow!("server did not respond with 200 OK"))
}
}

fn test_tls_sample_application() {
// since this is testing remote endpoint to ensure system cert store works
/// This test sets up a TCP connection using one domain, and then attempts to
/// perform a TLS handshake using another unrelated domain. This should result
/// in a handshake error.
fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
const BAD_DOMAIN: &'static str = "wrongdomain.localhost";

let net = Network::default();

let socket = TcpSocket::new(ip.family()).unwrap();
let (tcp_input, tcp_output) = socket
.blocking_connect(&net, IpSocketAddress::new(ip, PORT))
.context("tcp connect failed")?;

match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() {
// We're expecting an error regarding the "certificate" is some form or
// another. When we add more TLS backends other than rustls, this naive
// check will likely need to be revisited/expanded:
Err(e) if e.to_debug_string().contains("certificate") => Ok(()),

Err(e) => Err(e.into()),
Ok(_) => panic!("expecting server name mismatch"),
}
}

fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) {
// since this is testing remote endpoints to ensure system cert store works
// the test uses a couple different endpoints to reduce the number of flakes
const DOMAINS: &'static [&'static str] = &["example.com", "api.github.com"];

let net = Network::default();

for &domain in DOMAINS {
match make_tls_request(domain) {
Ok(r) => {
assert!(r.contains("HTTP/1.1 200 OK"));
return;
}
let lookup = net
.permissive_blocking_resolve_addresses(domain)
.unwrap()
.first()
.map(|a| a.to_owned())
.ok_or_else(|| anyhow!("DNS lookup failed."));

match lookup.and_then(|ip| test(&domain, ip)) {
Ok(()) => return,
Err(e) => {
eprintln!("Failed to make TLS request to {domain}: {e}");
eprintln!("test for {domain} failed: {e:#}");
}
}
}
panic!("All TLS requests failed.");

panic!("all tests failed");
}

fn main() {
test_tls_sample_application();
try_live_endpoints(test_tls_sample_application);
try_live_endpoints(test_tls_invalid_certificate);
}
8 changes: 8 additions & 0 deletions crates/test-programs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ pub mod proxy {
},
});
}

impl std::fmt::Display for wasi::io::error::Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.to_debug_string())
}
}

impl std::error::Error for wasi::io::error::Error {}
3 changes: 2 additions & 1 deletion crates/test-programs/src/tls.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::wasi::clocks::monotonic_clock;
use crate::wasi::io::error::Error as IoError;
use crate::wasi::io::streams::StreamError;
use crate::wasi::tls::types::{ClientConnection, ClientHandshake, InputStream, OutputStream};

const TIMEOUT_NS: u64 = 1_000_000_000;

impl ClientHandshake {
pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), ()> {
pub fn blocking_finish(self) -> Result<(ClientConnection, InputStream, OutputStream), IoError> {
let future = ClientHandshake::finish(self);
let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS * 200);
let pollable = future.subscribe();
Expand Down
116 changes: 88 additions & 28 deletions crates/wasi-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
#![doc(test(attr(deny(warnings))))]
#![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))]

use anyhow::{Context, Result};
use anyhow::Result;
use bytes::Bytes;
use rustls::pki_types::ServerName;
use std::io;
Expand All @@ -88,6 +88,7 @@ use wasmtime_wasi::OutputStream;
use wasmtime_wasi::{
async_trait,
bindings::io::{
error::Error as HostIoError,
poll::Pollable as HostPollable,
streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream},
},
Expand Down Expand Up @@ -149,6 +150,57 @@ pub fn add_to_linker<T: Send>(
generated::types::add_to_linker_get_host(l, &opts, f)?;
Ok(())
}

enum TlsError {
/// The component should trap. Under normal circumstances, this only occurs
/// when the underlying transport stream returns [`StreamError::Trap`].
Trap(anyhow::Error),

/// A failure indicated by the underlying transport stream as
/// [`StreamError::LastOperationFailed`].
Io(wasmtime_wasi::IoError),

/// A TLS protocol error occurred.
Tls(rustls::Error),
}

impl TlsError {
/// Create a [`TlsError::Tls`] error from a simple message.
fn msg(msg: &str) -> Self {
// (Ab)using rustls' error type to synthesize our own TLS errors:
Self::Tls(rustls::Error::General(msg.to_string()))
}
}

impl From<io::Error> for TlsError {
fn from(error: io::Error) -> Self {
// Report unexpected EOFs as an error to prevent truncation attacks.
// See: https://docs.rs/rustls/latest/rustls/struct.Reader.html#method.read
if let io::ErrorKind::WriteZero | io::ErrorKind::UnexpectedEof = error.kind() {
return Self::msg("underlying transport closed abruptly");
}

// Errors from underlying transport.
// These have been wrapped inside `io::Error`s by our wasi-to-tokio stream transformer below.
let error = match error.downcast::<StreamError>() {
Ok(StreamError::LastOperationFailed(e)) => return Self::Io(e),
Ok(StreamError::Trap(e)) => return Self::Trap(e),
Ok(StreamError::Closed) => unreachable!("our wasi-to-tokio stream transformer should have translated this to a 0-sized read"),
Err(e) => e,
};

// Errors from `rustls`.
// These have been wrapped inside `io::Error`s by `tokio-rustls`.
let error = match error.downcast::<rustls::Error>() {
Ok(e) => return Self::Tls(e),
Err(e) => e,
};

// All errors should have been handled by the clauses above.
Self::Trap(anyhow::Error::new(error).context("unknown wasi-tls error"))
}
}

/// Represents the ClientHandshake which will be used to configure the handshake
pub struct ClientHandShake {
server_name: String,
Expand Down Expand Up @@ -180,16 +232,17 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
let handshake = self.table.delete(this)?;
let server_name = handshake.server_name;
let streams = handshake.streams;
let domain = ServerName::try_from(server_name)?;

Ok(self
.table
.push(FutureStreams(StreamState::Pending(Box::pin(async move {
let connector = tokio_rustls::TlsConnector::from(default_client_config());
connector
let domain = ServerName::try_from(server_name)
.map_err(|_| TlsError::msg("invalid server name"))?;

let stream = tokio_rustls::TlsConnector::from(default_client_config())
.connect(domain, streams)
.await
.with_context(|| "connection failed")
.await?;
Ok(stream)
}))))?)
}

Expand All @@ -203,7 +256,7 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
}

/// Future streams provides the tls streams after the handshake is completed
pub struct FutureStreams<T>(StreamState<Result<T>>);
pub struct FutureStreams<T>(StreamState<Result<T, TlsError>>);

/// Library specific version of TLS connection after the handshake is completed.
/// This alias allows it to use with wit-bindgen component generator which won't take generic types
Expand Down Expand Up @@ -239,30 +292,36 @@ impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> {
Resource<BoxInputStream>,
Resource<BoxOutputStream>,
),
(),
Resource<HostIoError>,
>,
(),
>,
>,
> {
{
let this = self.table.get(&this)?;
match &this.0 {
StreamState::Pending(_) => return Ok(None),
StreamState::Ready(Ok(_)) => (),
StreamState::Ready(Err(_)) => {
return Ok(Some(Ok(Err(()))));
}
StreamState::Closed => return Ok(Some(Err(()))),
}
let this = &mut self.table.get_mut(&this)?.0;
match this {
StreamState::Pending(_) => return Ok(None),
StreamState::Closed => return Ok(Some(Err(()))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we be return some information about the stream being closed with this error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StreamState::Ready(_) => (),
}

let StreamState::Ready(Ok(tls_stream)) =
mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed)
else {
let StreamState::Ready(result) = mem::replace(this, StreamState::Closed) else {
unreachable!()
};

let tls_stream = match result {
Ok(s) => s,
Err(TlsError::Trap(e)) => return Err(e),
Err(TlsError::Io(e)) => {
let error = self.table.push(e)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need logic to clean up these table entries?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the table entries for the newly created wasi:io/error resources? The ownership of these resource is transferred into the guest one line below it. It's then the responsibility of the guest to drop it properly like any other resource. When the guest drops that resource, wasmtime removes it from the table in https://github.com/bytecodealliance/wasmtime/blob/main/crates/wasi-io/src/impls.rs#L114-L116

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the explanation

return Ok(Some(Ok(Err(error))));
}
Err(TlsError::Tls(e)) => {
let error = self.table.push(wasmtime_wasi::IoError::new(e))?;
return Ok(Some(Ok(Err(error))));
}
};

let (rx, tx) = tokio::io::split(tls_stream);
let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx));
let client = ClientConnection {
Expand Down Expand Up @@ -347,15 +406,15 @@ impl AsyncWrite for WasiStreams {
return match output.write(Bytes::copy_from_slice(&buf[..count])) {
Ok(()) => Poll::Ready(Ok(count)),
Err(StreamError::Closed) => Poll::Ready(Ok(0)),
Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
Poll::Ready(Err(std::io::Error::other(e)))
}
Err(e) => Poll::Ready(Err(std::io::Error::other(e))),
};
}
Err(StreamError::Closed) => return Poll::Ready(Ok(0)),
Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
return Poll::Ready(Err(std::io::Error::other(e)))
Err(StreamError::Closed) => {
// Our current version of tokio-rustls does not handle returning `Ok(0)` well.
// See: https://github.com/rustls/tokio-rustls/issues/92
return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
}
Err(e) => return Poll::Ready(Err(std::io::Error::other(e))),
};
}
}
Expand Down Expand Up @@ -621,7 +680,8 @@ mod tests {
let (tx1, rx1) = oneshot::channel::<()>();

let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move {
rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled"))
rx1.await
.map_err(|_| TlsError::Trap(anyhow::anyhow!("oneshot canceled")))
})));

let mut fut = future_streams.ready();
Expand Down
4 changes: 3 additions & 1 deletion crates/wasi-tls/wit/deps/tls/types.wit
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ interface types {
use wasi:io/streams@0.2.3.{input-stream, output-stream};
@unstable(feature = tls)
use wasi:io/poll@0.2.3.{pollable};
@unstable(feature = tls)
use wasi:io/error@0.2.3.{error as io-error};

@unstable(feature = tls)
resource client-handshake {
Expand All @@ -26,6 +28,6 @@ interface types {
subscribe: func() -> pollable;

@unstable(feature = tls)
get: func() -> option<result<result<tuple<client-connection, input-stream, output-stream>>>>;
get: func() -> option<result<result<tuple<client-connection, input-stream, output-stream>, io-error>>>;
}
}
3 changes: 2 additions & 1 deletion crates/wasi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ pub use wasmtime::component::{ResourceTable, ResourceTableError};
// users of this crate depend on them at these names.
pub use wasmtime_wasi_io::poll::{subscribe, DynFuture, DynPollable, MakeFuture, Pollable};
pub use wasmtime_wasi_io::streams::{
DynInputStream, DynOutputStream, InputStream, OutputStream, StreamError, StreamResult,
DynInputStream, DynOutputStream, Error as IoError, InputStream, OutputStream, StreamError,
StreamResult,
};
pub use wasmtime_wasi_io::{IoImpl, IoView};

Expand Down