From 87aefbed3a90b168b6069f54c3a03006f195fdac Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Tue, 16 Dec 2025 15:41:51 +0100 Subject: [PATCH 1/5] feat(p3): implement `wasi:tls` Signed-off-by: Roman Volosatovs --- Cargo.lock | 1 + ci/vendor-wit.sh | 2 + crates/test-programs/artifacts/build.rs | 9 +- ...cation.rs => p2_tls_sample_application.rs} | 0 .../src/bin/p3_tls_sample_application.rs | 168 ++++++++ crates/test-programs/src/p3/mod.rs | 6 +- crates/wasi-tls-nativetls/tests/main.rs | 6 +- crates/wasi-tls/Cargo.toml | 5 + crates/wasi-tls/src/lib.rs | 2 + crates/wasi-tls/src/p3/bindings.rs | 23 ++ crates/wasi-tls/src/p3/host/client.rs | 259 +++++++++++++ crates/wasi-tls/src/p3/host/mod.rs | 364 ++++++++++++++++++ crates/wasi-tls/src/p3/host/server.rs | 284 ++++++++++++++ crates/wasi-tls/src/p3/host/types.rs | 15 + crates/wasi-tls/src/p3/mod.rs | 180 +++++++++ .../wasi-tls/src/p3/wit/deps/tls/client.wit | 38 ++ .../wasi-tls/src/p3/wit/deps/tls/server.wit | 36 ++ crates/wasi-tls/src/p3/wit/deps/tls/types.wit | 5 + crates/wasi-tls/src/p3/wit/deps/tls/world.wit | 7 + crates/wasi-tls/src/p3/wit/world.wit | 2 + crates/wasi-tls/tests/{main.rs => p2.rs} | 6 +- crates/wasi-tls/tests/p3.rs | 82 ++++ 22 files changed, 1489 insertions(+), 11 deletions(-) rename crates/test-programs/src/bin/{tls_sample_application.rs => p2_tls_sample_application.rs} (100%) create mode 100644 crates/test-programs/src/bin/p3_tls_sample_application.rs create mode 100644 crates/wasi-tls/src/p3/bindings.rs create mode 100644 crates/wasi-tls/src/p3/host/client.rs create mode 100644 crates/wasi-tls/src/p3/host/mod.rs create mode 100644 crates/wasi-tls/src/p3/host/server.rs create mode 100644 crates/wasi-tls/src/p3/host/types.rs create mode 100644 crates/wasi-tls/src/p3/mod.rs create mode 100644 crates/wasi-tls/src/p3/wit/deps/tls/client.wit create mode 100644 crates/wasi-tls/src/p3/wit/deps/tls/server.wit create mode 100644 crates/wasi-tls/src/p3/wit/deps/tls/types.wit create mode 100644 crates/wasi-tls/src/p3/wit/deps/tls/world.wit create mode 100644 crates/wasi-tls/src/p3/wit/world.wit rename crates/wasi-tls/tests/{main.rs => p2.rs} (90%) create mode 100644 crates/wasi-tls/tests/p3.rs diff --git a/Cargo.lock b/Cargo.lock index be667f8883d8..1d39a6a67a05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5189,6 +5189,7 @@ dependencies = [ "test-programs-artifacts", "tokio", "tokio-rustls", + "tracing", "wasmtime", "wasmtime-wasi", "webpki-roots", diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 150b4ee149bf..18be392500c4 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -65,6 +65,8 @@ mkdir -p crates/wasi-tls/wit/deps wkg get --format wit --overwrite "wasi:io@$p2" -o "crates/wasi-tls/wit/deps/io.wit" get_github wasi-tls v0.2.0-draft+505fc98 crates/wasi-tls/wit/deps/tls +get_github wasi-tls v0.3.0-draft crates/wasi-tls/src/p3/wit/deps/tls + rm -rf crates/wasi-config/wit/deps mkdir -p crates/wasi-config/wit/deps get_github wasi-config v0.2.0-rc.1 crates/wasi-config/wit/deps/config diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index a9d81098b2b9..d6f122e4beeb 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -79,19 +79,20 @@ impl Artifacts { // generates a `foreach_*` macro below. let kind = match test.name.as_str() { s if s.starts_with("p1_") => "p1", - s if s.starts_with("p2_http_") => "p2_http", - s if s.starts_with("p2_cli_") => "p2_cli", s if s.starts_with("p2_api_") => "p2_api", + s if s.starts_with("p2_cli_") => "p2_cli", + s if s.starts_with("p2_http_") => "p2_http", + s if s.starts_with("p2_tls_") => "p2_tls", s if s.starts_with("p2_") => "p2", s if s.starts_with("nn_") => "nn", s if s.starts_with("piped_") => "piped", s if s.starts_with("dwarf_") => "dwarf", s if s.starts_with("config_") => "config", s if s.starts_with("keyvalue_") => "keyvalue", - s if s.starts_with("tls_") => "tls", s if s.starts_with("async_") => "async", - s if s.starts_with("p3_http_") => "p3_http", s if s.starts_with("p3_api_") => "p3_api", + s if s.starts_with("p3_http_") => "p3_http", + s if s.starts_with("p3_tls_") => "p3_tls", s if s.starts_with("p3_") => "p3", s if s.starts_with("fuzz_") => "fuzz", // If you're reading this because you hit this panic, either add diff --git a/crates/test-programs/src/bin/tls_sample_application.rs b/crates/test-programs/src/bin/p2_tls_sample_application.rs similarity index 100% rename from crates/test-programs/src/bin/tls_sample_application.rs rename to crates/test-programs/src/bin/p2_tls_sample_application.rs diff --git a/crates/test-programs/src/bin/p3_tls_sample_application.rs b/crates/test-programs/src/bin/p3_tls_sample_application.rs new file mode 100644 index 000000000000..4ecfbe4e71a4 --- /dev/null +++ b/crates/test-programs/src/bin/p3_tls_sample_application.rs @@ -0,0 +1,168 @@ +use anyhow::{Context as _, Result, anyhow, bail}; +use core::future::{Future as _, poll_fn}; +use core::pin::pin; +use core::str; +use core::task::{Poll, ready}; +use futures::try_join; +use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses; +use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket}; +use test_programs::p3::wasi::tls; +use test_programs::p3::wasi::tls::client::Hello; +use test_programs::p3::wit_stream; +use wit_bindgen::StreamResult; + +struct Component; + +test_programs::p3::export!(Component); + +const PORT: u16 = 443; + +async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> { + let request = format!( + "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n" + ); + + let sock = TcpSocket::create(ip.family()).unwrap(); + sock.connect(IpSocketAddress::new(ip, PORT)) + .await + .context("tcp connect failed")?; + + let (sock_rx, sock_rx_fut) = sock.receive(); + let hello = Hello::new(); + hello + .set_server_name(domain) + .map_err(|()| anyhow!("failed to set SNI"))?; + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); + let sock_tx_fut = sock.send(sock_tx); + + let mut conn = pin!(conn.into_future()); + let mut sock_rx_fut = pin!(sock_rx_fut.into_future()); + let mut sock_tx_fut = pin!(sock_tx_fut); + let conn = poll_fn(|cx| match conn.as_mut().poll(cx) { + Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)), + Poll::Ready(Err(())) => Poll::Ready(Err(anyhow!("tls handshake failed"))), + Poll::Pending => match sock_tx_fut.as_mut().poll(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Err(anyhow!("Tx stream closed unexpectedly"))), + Poll::Ready(Err(err)) => { + Poll::Ready(Err(anyhow!("Tx stream closed with error: {err:?}"))) + } + Poll::Pending => match ready!(sock_rx_fut.as_mut().poll(cx)) { + Ok(_) => Poll::Ready(Err(anyhow!("Rx stream closed unexpectedly"))), + Err(err) => Poll::Ready(Err(anyhow!("Rx stream closed with error: {err:?}"))), + }, + }, + }) + .await?; + + let (mut req_tx, req_rx) = wit_stream::new(); + let (mut res_rx, result_fut) = tls::client::Handshake::finish(conn, req_rx); + + let res = Vec::with_capacity(8192); + try_join!( + async { + let buf = req_tx.write_all(request.into()).await; + assert_eq!(buf, []); + drop(req_tx); + Ok(()) + }, + async { + let (result, buf) = res_rx.read(res).await; + match result { + StreamResult::Complete(..) => { + drop(res_rx); + let res = String::from_utf8(buf)?; + if res.contains("HTTP/1.1 200 OK") { + Ok(()) + } else { + bail!("server did not respond with 200 OK: {res}") + } + } + StreamResult::Dropped => bail!("read dropped"), + StreamResult::Cancelled => bail!("read cancelled"), + } + }, + async { result_fut.await.map_err(|()| anyhow!("TLS session failed")) }, + async { sock_rx_fut.await.context("TCP receipt failed") }, + async { sock_tx_fut.await.context("TCP transmit failed") }, + )?; + Ok(()) +} + +/// This test sets up a TCP connection using one domain, and then attempts to +/// perform a TLS handshake using another unrelated domain. This should result +/// in a handshake error. +async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { + const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; + + let sock = TcpSocket::create(ip.family()).unwrap(); + sock.connect(IpSocketAddress::new(ip, PORT)) + .await + .context("tcp connect failed")?; + + let (sock_rx, sock_rx_fut) = sock.receive(); + let hello = Hello::new(); + hello + .set_server_name(BAD_DOMAIN) + .map_err(|()| anyhow!("failed to set SNI"))?; + let (sock_tx, conn) = tls::client::connect(hello, sock_rx); + let sock_tx_fut = sock.send(sock_tx); + + try_join!( + async { + match conn.await { + Err(()) => Ok(()), + Ok(_) => panic!("expecting server name mismatch"), + } + }, + async { sock_rx_fut.await.context("TCP receipt failed") }, + async { sock_tx_fut.await.context("TCP transmit failed") }, + )?; + Ok(()) +} + +async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) +where + Fut: Future> + 'a, +{ + // since this is testing remote endpoints to ensure system cert store works + // the test uses a couple different endpoints to reduce the number of flakes + const DOMAINS: &'static [&'static str] = &[ + "example.com", + "api.github.com", + "docs.wasmtime.dev", + "bytecodealliance.org", + "www.rust-lang.org", + ]; + + for &domain in DOMAINS { + let result = (|| async { + let ip = resolve_addresses(domain.into()) + .await? + .first() + .map(|a| a.to_owned()) + .ok_or_else(|| anyhow!("DNS lookup failed."))?; + test(&domain, ip).await + })(); + + match result.await { + Ok(()) => return, + Err(e) => { + eprintln!("test for {domain} failed: {e:#}"); + } + } + } + + panic!("all tests failed"); +} + +impl test_programs::p3::exports::wasi::cli::run::Guest for Component { + async fn run() -> Result<(), ()> { + println!("sample app"); + try_live_endpoints(test_tls_sample_application).await; + println!("invalid cert"); + try_live_endpoints(test_tls_invalid_certificate).await; + Ok(()) + } +} + +fn main() {} diff --git a/crates/test-programs/src/p3/mod.rs b/crates/test-programs/src/p3/mod.rs index d5dea4f18439..7d2eb951ed71 100644 --- a/crates/test-programs/src/p3/mod.rs +++ b/crates/test-programs/src/p3/mod.rs @@ -7,6 +7,7 @@ wit_bindgen::generate!({ world testp3 { include wasi:cli/imports@0.3.0-rc-2026-02-09; + include wasi:tls/imports@0.3.0-draft; import wasi:http/types@0.3.0-rc-2026-02-09; import wasi:http/client@0.3.0-rc-2026-02-09; import wasi:http/handler@0.3.0-rc-2026-02-09; @@ -14,7 +15,10 @@ wit_bindgen::generate!({ export wasi:cli/run@0.3.0-rc-2026-02-09; } ", - path: "../wasi-http/src/p3/wit", + path: [ + "../wasi-http/src/p3/wit", + "../wasi-tls/src/p3/wit", + ], world: "wasmtime:test/testp3", default_bindings_module: "test_programs::p3", pub_export_macro: true, diff --git a/crates/wasi-tls-nativetls/tests/main.rs b/crates/wasi-tls-nativetls/tests/main.rs index 0aa344a815ec..be4ab471caba 100644 --- a/crates/wasi-tls-nativetls/tests/main.rs +++ b/crates/wasi-tls-nativetls/tests/main.rs @@ -60,9 +60,9 @@ macro_rules! assert_test_exists { }; } -test_programs_artifacts::foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_p2_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +async fn p2_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P2_TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index a94af76ca3c5..2aa1f0c0b1ff 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -11,6 +11,10 @@ description = "Wasmtime implementation of the wasi-tls API" [lints] workspace = true +[features] +default = [] +p3 = ["wasmtime-wasi/p3", "wasmtime/component-model-async"] + [dependencies] bytes = { workspace = true } tokio = { workspace = true, features = [ @@ -19,6 +23,7 @@ tokio = { workspace = true, features = [ "time", "io-util", ] } +tracing = { workspace = true } wasmtime = { workspace = true, features = ["runtime", "component-model"] } wasmtime-wasi = { workspace = true } diff --git a/crates/wasi-tls/src/lib.rs b/crates/wasi-tls/src/lib.rs index 0f011f0bde29..cc14f379bef5 100644 --- a/crates/wasi-tls/src/lib.rs +++ b/crates/wasi-tls/src/lib.rs @@ -74,6 +74,8 @@ use wasmtime::component::{HasData, ResourceTable}; pub mod bindings; mod host; mod io; +#[cfg(feature = "p3")] +pub mod p3; mod rustls; pub use bindings::types::LinkOptions; diff --git a/crates/wasi-tls/src/p3/bindings.rs b/crates/wasi-tls/src/p3/bindings.rs new file mode 100644 index 000000000000..5dc68fb01be7 --- /dev/null +++ b/crates/wasi-tls/src/p3/bindings.rs @@ -0,0 +1,23 @@ +//! Raw bindings to the `wasi:tls` package. + +#[expect(missing_docs, reason = "generated code")] +mod generated { + wasmtime::component::bindgen!({ + path: "src/p3/wit", + world: "wasi:tls/imports", + imports: { + "wasi:tls/client.[static]handshake.finish": trappable | tracing | store, + "wasi:tls/client.connect": trappable | tracing | store, + "wasi:tls/server.[static]handshake.finish": trappable | tracing | store, + default: trappable | tracing + }, + with: { + "wasi:tls/client.handshake": crate::p3::ClientHandshake, + "wasi:tls/client.hello": crate::p3::ClientHello, + "wasi:tls/server.handshake": crate::p3::ServerHandshake, + "wasi:tls/types.certificate": crate::p3::Certificate, + }, + }); +} + +pub use self::generated::wasi::*; diff --git a/crates/wasi-tls/src/p3/host/client.rs b/crates/wasi-tls/src/p3/host/client.rs new file mode 100644 index 000000000000..440bd6dd5a67 --- /dev/null +++ b/crates/wasi-tls/src/p3/host/client.rs @@ -0,0 +1,259 @@ +use super::{ + CiphertextConsumer, CiphertextProducer, PlaintextConsumer, PlaintextProducer, ResultProducer, + mk_delete, mk_get, mk_get_mut, mk_push, +}; +use crate::p3::bindings::tls::client::{ + Handshake, Hello, Host, HostHandshake, HostHandshakeWithStore, HostHello, HostWithStore, +}; +use crate::p3::bindings::tls::types::Certificate; +use crate::p3::{TlsStream, TlsStreamClientArc, WasiTls, WasiTlsCtxView}; +use anyhow::{Context as _, anyhow, bail}; +use core::mem; +use core::net::{IpAddr, Ipv4Addr}; +use core::pin::{Pin, pin}; +use core::task::{Context, Poll}; +use rustls::client::ResolvesClientCert; +use rustls::pki_types::ServerName; +use std::sync::{Arc, Mutex}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{Access, FutureProducer, FutureReader, Resource, StreamReader}; + +mk_push!(Hello, push_hello, "client hello"); +mk_get_mut!(Hello, get_hello_mut, "client hello"); +mk_delete!(Hello, delete_hello, "client hello"); + +mk_push!(Handshake, push_handshake, "client handshake"); +mk_get!(Handshake, get_handshake, "client handshake"); +mk_delete!(Handshake, delete_handshake, "client handshake"); + +#[derive(Default)] +enum ConnectProducer { + Pending { + stream: TlsStreamClientArc, + error_rx: oneshot::Receiver, + getter: fn(&mut T) -> WasiTlsCtxView<'_>, + }, + #[default] + Exhausted, +} + +impl FutureProducer for ConnectProducer +where + D: 'static, +{ + type Item = Result, ()>; + + fn poll_produce( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut store: StoreContextMut, + finish: bool, + ) -> Poll>> { + let this = self.get_mut(); + let Self::Pending { + stream, + mut error_rx, + getter, + } = mem::take(this) + else { + return Poll::Ready(Err(anyhow!("polled after ready"))); + }; + if let Poll::Ready(..) = pin!(&mut error_rx).poll(cx) { + return Poll::Ready(Ok(Some(Err(())))); + } + + { + let mut stream_lock = stream.lock(); + let TlsStream { conn, read_tls, .. } = stream_lock.as_deref_mut().unwrap(); + if conn.peer_certificates().is_none() || conn.negotiated_cipher_suite().is_none() { + if !finish { + *read_tls = Some(cx.waker().clone()); + } + drop(stream_lock); + *this = Self::Pending { + stream, + error_rx, + getter, + }; + if finish { + return Poll::Ready(Ok(None)); + } + return Poll::Pending; + } + }; + + let WasiTlsCtxView { table, .. } = getter(store.data_mut()); + + let handshake = Handshake { stream, error_rx }; + let handshake = push_handshake(table, handshake)?; + + Poll::Ready(Ok(Some(Ok(handshake)))) + } +} + +#[derive(Debug)] +struct CertificateResolver; + +impl ResolvesClientCert for CertificateResolver { + fn resolve( + &self, + _root_hint_subjects: &[&[u8]], + _sigschemes: &[rustls::SignatureScheme], + ) -> Option> { + // TODO: implement + None + } + + fn has_certs(&self) -> bool { + false + } +} + +impl Host for WasiTlsCtxView<'_> {} + +impl HostHello for WasiTlsCtxView<'_> { + fn new(&mut self) -> wasmtime::Result> { + push_hello(&mut self.table, Hello::default()) + } + + fn set_server_name( + &mut self, + hello: Resource, + server_name: String, + ) -> wasmtime::Result> { + let hello = get_hello_mut(&mut self.table, &hello)?; + let Ok(server_name) = server_name.try_into() else { + return Ok(Err(())); + }; + hello.server_name = Some(server_name); + Ok(Ok(())) + } + + fn set_alpn_ids( + &mut self, + hello: Resource, + alpn_ids: Vec>, + ) -> wasmtime::Result<()> { + let hello = get_hello_mut(&mut self.table, &hello)?; + hello.alpn_ids = Some(alpn_ids); + Ok(()) + } + + fn set_cipher_suites( + &mut self, + hello: Resource, + cipher_suites: Vec, + ) -> wasmtime::Result<()> { + let hello = get_hello_mut(&mut self.table, &hello)?; + hello.cipher_suites = cipher_suites; + Ok(()) + } + + fn drop(&mut self, hello: Resource) -> wasmtime::Result<()> { + delete_hello(&mut self.table, hello)?; + Ok(()) + } +} + +impl HostWithStore for WasiTls { + fn connect( + mut store: Access, + hello: Resource, + incoming: StreamReader, + ) -> wasmtime::Result<( + StreamReader, + FutureReader, ()>>, + )> { + let Hello { + server_name, + alpn_ids, + cipher_suites, + } = delete_hello(store.get().table, hello)?; + + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + if !cipher_suites.is_empty() { + // TODO: implement + bail!("custom cipher suites not supported yet") + } + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_client_cert_resolver(Arc::new(CertificateResolver)); + if let Some(alpn_ids) = alpn_ids { + config.alpn_protocols = alpn_ids; + } + let server_name = if let Some(server_name) = server_name { + server_name + } else { + config.enable_sni = false; + ServerName::IpAddress(IpAddr::V4(Ipv4Addr::UNSPECIFIED).into()) + }; + let conn = rustls::ClientConnection::new(Arc::from(config), server_name) + .context("failed to construct rustls client connection")?; + let (error_tx, error_rx) = oneshot::channel(); + let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); + + incoming.pipe(&mut store, CiphertextConsumer(Arc::clone(&stream))); + let getter = store.getter(); + Ok(( + StreamReader::new(&mut store, CiphertextProducer(Arc::clone(&stream))), + FutureReader::new( + &mut store, + ConnectProducer::Pending { + stream, + error_rx, + getter, + }, + ), + )) + } +} + +impl HostHandshake for WasiTlsCtxView<'_> { + fn set_client_certificate( + &mut self, + _handshake: Resource, + _cert: Resource, + ) -> wasmtime::Result<()> { + todo!() + } + + fn get_server_certificate( + &mut self, + _handshake: Resource, + ) -> wasmtime::Result>> { + todo!() + } + + fn get_cipher_suite(&mut self, handshake: Resource) -> wasmtime::Result { + let Handshake { stream, .. } = get_handshake(&self.table, &handshake)?; + let mut stream = stream.lock(); + let TlsStream { conn, .. } = stream.as_deref_mut().unwrap(); + let cipher_suite = conn + .negotiated_cipher_suite() + .context("cipher suite not available")?; + Ok(cipher_suite.suite().get_u16()) + } + + fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { + delete_handshake(&mut self.table, handshake)?; + Ok(()) + } +} + +impl HostHandshakeWithStore for WasiTls { + fn finish( + mut store: Access, + handshake: Resource, + data: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>)> { + let Handshake { stream, error_rx } = delete_handshake(&mut store.get().table, handshake)?; + data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); + Ok(( + StreamReader::new(&mut store, PlaintextProducer(stream)), + FutureReader::new(&mut store, ResultProducer(error_rx)), + )) + } +} diff --git a/crates/wasi-tls/src/p3/host/mod.rs b/crates/wasi-tls/src/p3/host/mod.rs new file mode 100644 index 000000000000..2b48c7f0b03e --- /dev/null +++ b/crates/wasi-tls/src/p3/host/mod.rs @@ -0,0 +1,364 @@ +use crate::p3::{TlsStream, TlsStreamArc}; +use anyhow::Context as _; +use core::ops::DerefMut; +use core::pin::Pin; +use core::task::{Context, Poll, Waker}; +use std::io::{Read as _, Write as _}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{ + Destination, FutureProducer, Source, StreamConsumer, StreamProducer, StreamResult, +}; + +mod client; +mod server; +mod types; + +macro_rules! mk_push { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f( + table: &mut wasmtime::component::ResourceTable, + value: $t, + ) -> wasmtime::Result> { + use anyhow::Context as _; + + table + .push(value) + .context(concat!("failed to push ", $desc, " resource to table")) + } + }; +} + +macro_rules! mk_get { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f<'a>( + table: &'a wasmtime::component::ResourceTable, + resource: &'a wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<&'a $t> { + use anyhow::Context as _; + + table + .get(resource) + .context(concat!("failed to get ", $desc, " resource from table")) + } + }; +} + +macro_rules! mk_get_mut { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f<'a>( + table: &'a mut wasmtime::component::ResourceTable, + resource: &'a wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<&'a mut $t> { + use anyhow::Context as _; + + table.get_mut(resource).context(concat!( + "failed to get ", + $desc, + " resource from table" + )) + } + }; +} + +macro_rules! mk_delete { + ($t:ty, $f:ident, $desc:literal) => { + #[track_caller] + #[inline] + pub fn $f( + table: &mut wasmtime::component::ResourceTable, + resource: wasmtime::component::Resource<$t>, + ) -> wasmtime::Result<$t> { + use anyhow::Context as _; + + table.delete(resource).context(concat!( + "failed to delete ", + $desc, + " resource from table" + )) + } + }; +} + +pub(crate) use {mk_delete, mk_get, mk_get_mut, mk_push}; + +struct CiphertextConsumer(TlsStreamArc); + +impl StreamConsumer for CiphertextConsumer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + read_tls, + ciphertext_consumer, + ciphertext_producer, + plaintext_consumer, + plaintext_producer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + if !conn.wants_read() { + if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *ciphertext_consumer = Some(cx.waker().clone()); + return Poll::Pending; + } + + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.read_tls(&mut src)?; + debug_assert_ne!(n, 0); + read_tls.take().map(Waker::wake); + + let state = match conn.process_new_packets() { + Ok(state) => state, + Err(err) => { + _ = error_tx.take().unwrap().send(err); + ciphertext_producer.take().map(Waker::wake); + plaintext_consumer.take().map(Waker::wake); + plaintext_producer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + }; + + if state.plaintext_bytes_to_read() > 0 { + plaintext_producer.take().map(Waker::wake); + } + + if state.tls_bytes_to_write() > 0 { + ciphertext_producer.take().map(Waker::wake); + } + + if state.peer_has_closed() { + // even if there are no bytes to read, the producer may be pending + plaintext_producer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct PlaintextProducer(TlsStreamArc); + +impl StreamProducer for PlaintextProducer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + ciphertext_consumer, + plaintext_producer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let state = conn.process_new_packets().context("unhandled TLS error")?; + if state.plaintext_bytes_to_read() == 0 { + if state.peer_has_closed() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } else if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *plaintext_producer = Some(cx.waker().clone()); + return Poll::Pending; + } + + let mut dst = dst.as_direct(store, state.plaintext_bytes_to_read()); + let buf = dst.remaining(); + if buf.is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.reader().read(buf)?; + debug_assert_ne!(n, 0); + dst.mark_written(n); + if conn.wants_read() { + ciphertext_consumer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct PlaintextConsumer(TlsStreamArc) +where + T: DerefMut> + Send + 'static; + +impl Drop for PlaintextConsumer +where + T: DerefMut> + Send + 'static, +{ + fn drop(&mut self) { + let mut stream = self.0.lock(); + let TlsStream { + conn, + close_notify, + ciphertext_producer, + .. + } = stream.as_deref_mut().unwrap(); + conn.send_close_notify(); + *close_notify = true; + ciphertext_producer.take().map(Waker::wake); + } +} + +impl StreamConsumer for PlaintextConsumer +where + T: DerefMut> + Send + 'static, + U: 'static, +{ + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + ciphertext_producer, + plaintext_consumer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + + let mut dst = conn.writer(); + let n = dst.write(src.remaining())?; + if n == 0 { + if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *plaintext_consumer = Some(cx.waker().clone()); + return Poll::Pending; + } + src.mark_read(n); + dst.flush()?; + if conn.wants_write() { + ciphertext_producer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct CiphertextProducer(TlsStreamArc); + +impl StreamProducer for CiphertextProducer +where + T: DerefMut> + Send + 'static, +{ + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let mut stream = self.0.lock(); + let TlsStream { + conn, + error_tx, + close_notify, + ciphertext_consumer, + ciphertext_producer, + plaintext_consumer, + .. + } = stream.as_deref_mut().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + if !conn.wants_write() { + if *close_notify { + return Poll::Ready(Ok(StreamResult::Dropped)); + } else if finish { + return Poll::Ready(Ok(StreamResult::Cancelled)); + } + *ciphertext_producer = Some(cx.waker().clone()); + plaintext_consumer.take().map(Waker::wake); + return Poll::Pending; + } + + let state = conn.process_new_packets().context("unhandled TLS error")?; + let mut dst = dst.as_direct(store, state.tls_bytes_to_write()); + if dst.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + let n = conn.write_tls(&mut dst)?; + debug_assert_ne!(n, 0); + if conn.wants_read() { + ciphertext_consumer.take().map(Waker::wake); + } + Poll::Ready(Ok(StreamResult::Completed)) + } +} + +struct ResultProducer(oneshot::Receiver); + +impl FutureProducer for ResultProducer { + type Item = Result<(), ()>; + + fn poll_produce( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + _store: StoreContextMut, + finish: bool, + ) -> Poll>> { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(_err)) => Poll::Ready(Ok(Some(Err(())))), + Poll::Ready(Err(..)) => Poll::Ready(Ok(Some(Ok(())))), + Poll::Pending if finish => Poll::Ready(Ok(None)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/crates/wasi-tls/src/p3/host/server.rs b/crates/wasi-tls/src/p3/host/server.rs new file mode 100644 index 000000000000..94ec6481a477 --- /dev/null +++ b/crates/wasi-tls/src/p3/host/server.rs @@ -0,0 +1,284 @@ +#![expect(unused, reason = "WIP")] + +use super::{PlaintextConsumer, PlaintextProducer, ResultProducer, mk_delete, mk_get, mk_push}; +use crate::p3::bindings::tls::server::{ + Handshake, Host, HostHandshake, HostHandshakeWithStore, HostWithStore, +}; +use crate::p3::bindings::tls::types::Certificate; +use crate::p3::{TlsStream, TlsStreamServerArc, WasiTls, WasiTlsCtxView}; +use anyhow::{Context as _, anyhow}; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; +use rustls::server::ResolvesServerCert; +use std::sync::{Arc, Mutex}; +use tokio::sync::oneshot; +use wasmtime::StoreContextMut; +use wasmtime::component::{ + Access, Accessor, Destination, FutureReader, Resource, Source, StreamConsumer, StreamProducer, + StreamReader, StreamResult, +}; + +mk_delete!(Handshake, delete_handshake, "server handshake"); +mk_get!(Handshake, get_handshake, "server handshake"); +mk_push!(Handshake, push_handshake, "server handshake"); + +enum CiphertextConsumer { + Pending { + acceptor: rustls::server::Acceptor, + tx: oneshot::Sender< + Result< + ( + rustls::server::Accepted, + oneshot::Sender, + ), + rustls::Error, + >, + >, + }, + Accepted(oneshot::Receiver), + Active(super::CiphertextConsumer), + Corrupted, +} + +impl StreamConsumer for CiphertextConsumer { + type Item = u8; + + fn poll_consume( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + let this = self.get_mut(); + match mem::replace(this, Self::Corrupted) { + Self::Pending { mut acceptor, tx } => { + let mut src = src.as_direct(store); + if src.remaining().is_empty() { + return Poll::Ready(Ok(StreamResult::Completed)); + } + acceptor.read_tls(&mut src)?; + match acceptor.accept() { + Ok(None) => { + *this = Self::Pending { acceptor, tx }; + Poll::Ready(Ok(StreamResult::Completed)) + } + Ok(Some(accepted)) => { + let (stream_tx, stream_rx) = oneshot::channel(); + _ = tx.send(Ok((accepted, stream_tx))); + *this = Self::Accepted(stream_rx); + Poll::Ready(Ok(StreamResult::Completed)) + } + Err(err) => { + _ = tx.send(Err(err)); + Poll::Ready(Ok(StreamResult::Dropped)) + } + } + } + Self::Accepted(mut rx) => match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(stream)) => { + *this = Self::Active(super::CiphertextConsumer(stream)); + Poll::Ready(Ok(StreamResult::Completed)) + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => { + *this = Self::Accepted(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + Poll::Pending => { + *this = Self::Accepted(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + }, + Self::Active(ref mut conn) => Pin::new(conn).poll_consume(cx, store, src, finish), + Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream consumer state"))), + } + } +} + +enum CiphertextProducer { + Pending(oneshot::Receiver), + Active(super::CiphertextProducer), + Corrupted, +} + +impl StreamProducer for CiphertextProducer { + type Item = u8; + type Buffer = Option; // unused + + fn poll_produce<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + let this = self.get_mut(); + match mem::replace(this, Self::Corrupted) { + Self::Pending(mut rx) => match Pin::new(&mut rx).poll(cx) { + Poll::Ready(Ok(stream)) => { + *this = Self::Active(super::CiphertextProducer(stream)); + Poll::Ready(Ok(StreamResult::Completed)) + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => { + *this = Self::Pending(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + Poll::Pending => { + *this = Self::Pending(rx); + Poll::Ready(Ok(StreamResult::Cancelled)) + } + }, + Self::Active(ref mut conn) => Pin::new(conn).poll_produce(cx, store, dst, finish), + Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream producer state"))), + } + } +} + +#[derive(Debug)] +struct CertificateResolver; + +impl ResolvesServerCert for CertificateResolver { + fn resolve( + &self, + hello: rustls::server::ClientHello, + ) -> Option> { + // TODO: Implement + None + } +} + +impl Host for WasiTlsCtxView<'_> {} + +impl HostWithStore for WasiTls { + async fn accept( + store: &Accessor, + incoming: StreamReader, + ) -> wasmtime::Result, Resource), ()>> { + let (accept_tx, accept_rx) = oneshot::channel(); + store.with(|store| { + incoming.pipe( + store, + CiphertextConsumer::Pending { + acceptor: rustls::server::Acceptor::default(), + tx: accept_tx, + }, + ); + }); + let (accepted, consumer_tx) = match accept_rx + .await + .context("oneshot sender dropped unexpectedly")? + { + Ok((accepted, consumer_tx)) => (accepted, consumer_tx), + Err(_err) => return Ok(Err(())), + }; + let (producer_tx, producer_rx) = oneshot::channel(); + store.with(|mut store| { + let handshake = push_handshake( + store.get().table, + Handshake { + accepted, + consumer_tx, + producer_tx, + }, + )?; + Ok(Ok(( + StreamReader::new(&mut store, CiphertextProducer::Pending(producer_rx)), + handshake, + ))) + }) + } +} + +impl HostHandshake for WasiTlsCtxView<'_> { + fn set_server_certificate( + &mut self, + handshake: Resource, + cert: Resource, + ) -> wasmtime::Result<()> { + todo!() + } + + fn get_client_certificate( + &mut self, + handshake: Resource, + ) -> wasmtime::Result, ()>>> { + todo!() + } + + fn get_server_name( + &mut self, + handshake: Resource, + ) -> wasmtime::Result> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let server_name = hello.server_name().map(Into::into); + Ok(server_name) + } + + fn get_alpn_ids( + &mut self, + handshake: Resource, + ) -> wasmtime::Result>>> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let alpn = hello.alpn().map(|alpn| alpn.map(Into::into).collect()); + Ok(alpn) + } + + fn get_cipher_suites(&mut self, handshake: Resource) -> wasmtime::Result> { + let handshake = get_handshake(&self.table, &handshake)?; + let hello = handshake.accepted.client_hello(); + let cipher_suites = hello + .cipher_suites() + .into_iter() + .map(rustls::CipherSuite::get_u16) + .collect(); + Ok(cipher_suites) + } + + fn set_cipher_suite( + &mut self, + handshake: Resource, + cipher_suite: u16, + ) -> wasmtime::Result<()> { + todo!() + } + + fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { + delete_handshake(&mut self.table, handshake)?; + Ok(()) + } +} + +impl HostHandshakeWithStore for WasiTls { + fn finish( + mut store: Access, + handshake: Resource, + data: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>)> { + let Handshake { + accepted, + consumer_tx, + producer_tx, + } = delete_handshake(&mut store.get().table, handshake)?; + // TODO: configure + let config = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(CertificateResolver)); + let conn = accepted + .into_connection(Arc::from(config)) + .context("failed to construct rustls server connection")?; + let (error_tx, error_rx) = oneshot::channel(); + let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); + data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); + _ = consumer_tx.send(Arc::clone(&stream)); + _ = producer_tx.send(Arc::clone(&stream)); + Ok(( + StreamReader::new(&mut store, PlaintextProducer(stream)), + FutureReader::new(&mut store, ResultProducer(error_rx)), + )) + } +} diff --git a/crates/wasi-tls/src/p3/host/types.rs b/crates/wasi-tls/src/p3/host/types.rs new file mode 100644 index 000000000000..2c71ef0fbb5e --- /dev/null +++ b/crates/wasi-tls/src/p3/host/types.rs @@ -0,0 +1,15 @@ +use super::mk_delete; +use crate::p3::WasiTlsCtxView; +use crate::p3::bindings::tls::types::{Certificate, Host, HostCertificate}; +use wasmtime::component::Resource; + +mk_delete!(Certificate, delete_certificate, "certificate"); + +impl Host for WasiTlsCtxView<'_> {} + +impl HostCertificate for WasiTlsCtxView<'_> { + fn drop(&mut self, cert: Resource) -> wasmtime::Result<()> { + delete_certificate(&mut self.table, cert)?; + Ok(()) + } +} diff --git a/crates/wasi-tls/src/p3/mod.rs b/crates/wasi-tls/src/p3/mod.rs new file mode 100644 index 000000000000..e3be4924a4b4 --- /dev/null +++ b/crates/wasi-tls/src/p3/mod.rs @@ -0,0 +1,180 @@ +//! Experimental, unstable and incomplete implementation of wasip3 version of `wasi:tls`. +//! +//! This module is under heavy development. +//! It is not compliant with semver and is not ready +//! for production use. +//! +//! Bug and security fixes limited to wasip3 will not be given patch releases. +//! +//! Documentation of this module may be incorrect or out-of-sync with the implementation. + +pub mod bindings; +mod host; + +use core::task::Waker; +use std::sync::{Arc, Mutex}; + +use bindings::tls::{client, server, types}; +use rustls::pki_types::ServerName; +use tokio::sync::oneshot; +use wasmtime::component::{HasData, Linker, ResourceTable}; + +/// The type for which this crate implements the `wasi:tls` interfaces. +pub struct WasiTls; + +impl HasData for WasiTls { + type Data<'a> = WasiTlsCtxView<'a>; +} + +/// A trait which provides internal WASI TLS state. +pub trait WasiTlsCtx: Send {} + +/// Default implementation of [WasiTlsCtx]. +#[derive(Clone, Default)] +pub struct DefaultWasiTlsCtx; + +impl WasiTlsCtx for DefaultWasiTlsCtx {} + +/// View into [WasiTlsCtx] implementation and [ResourceTable]. +pub struct WasiTlsCtxView<'a> { + /// Mutable reference to the WASI TLS context. + pub ctx: &'a mut dyn WasiTlsCtx, + + /// Mutable reference to table used to manage resources. + pub table: &'a mut ResourceTable, +} + +/// A trait which provides internal WASI TLS state. +pub trait WasiTlsView: Send { + /// Return a [WasiTlsCtxView] from mutable reference to self. + fn tls(&mut self) -> WasiTlsCtxView<'_>; +} + +/// Add all interfaces from this module into the `linker` provided. +/// +/// This function will add all interfaces implemented by this module to the +/// [`Linker`], which corresponds to the `wasi:tls/imports` world supported by +/// this module. +/// +/// # Example +/// +/// ``` +/// use wasmtime::{Engine, Result, Store, Config}; +/// use wasmtime::component::{Linker, ResourceTable}; +/// use wasmtime_wasi_tls::p3::{DefaultWasiTlsCtx, WasiTlsCtxView, WasiTlsView}; +/// +/// fn main() -> Result<()> { +/// let mut config = Config::new(); +/// config.async_support(true); +/// config.wasm_component_model_async(true); +/// let engine = Engine::new(&config)?; +/// +/// let mut linker = Linker::::new(&engine); +/// wasmtime_wasi_tls::p3::add_to_linker(&mut linker)?; +/// // ... add any further functionality to `linker` if desired ... +/// +/// let mut store = Store::new( +/// &engine, +/// MyState::default(), +/// ); +/// +/// // ... use `linker` to instantiate within `store` ... +/// +/// Ok(()) +/// } +/// +/// #[derive(Default)] +/// struct MyState { +/// tls: DefaultWasiTlsCtx, +/// table: ResourceTable, +/// } +/// +/// impl WasiTlsView for MyState { +/// fn tls(&mut self) -> WasiTlsCtxView<'_> { +/// WasiTlsCtxView { +/// ctx: &mut self.tls, +/// table: &mut self.table, +/// } +/// } +/// } +/// ``` +pub fn add_to_linker(linker: &mut Linker) -> wasmtime::Result<()> +where + T: WasiTlsView + 'static, +{ + client::add_to_linker::<_, WasiTls>(linker, T::tls)?; + server::add_to_linker::<_, WasiTls>(linker, T::tls)?; + types::add_to_linker::<_, WasiTls>(linker, T::tls)?; + Ok(()) +} + +/// Client hello +#[derive(Clone, Default, Eq, PartialEq, Hash)] +pub struct ClientHello { + /// Server name indicator. + pub server_name: Option>, + /// ALPN IDs + pub alpn_ids: Option>>, + /// Cipher suites + pub cipher_suites: Vec, +} + +/// Server hello +#[derive(Clone, Eq, PartialEq, Hash)] +pub struct ServerHello { + /// Cipher suite + pub cipher_suite: u16, +} + +impl ServerHello { + /// Constructs a new server hello message + pub fn new(cipher_suite: u16) -> Self { + Self { cipher_suite } + } +} + +type TlsStreamArc = Arc>>; +type TlsStreamClientArc = TlsStreamArc; +type TlsStreamServerArc = TlsStreamArc; + +/// Client handshake +pub struct ClientHandshake { + stream: TlsStreamClientArc, + error_rx: oneshot::Receiver, +} + +/// Server handshake +pub struct ServerHandshake { + accepted: rustls::server::Accepted, + consumer_tx: oneshot::Sender, + producer_tx: oneshot::Sender, +} + +/// Certificate +pub struct Certificate; + +struct TlsStream { + conn: T, + error_tx: Option>, + close_notify: bool, + read_tls: Option, + ciphertext_consumer: Option, + ciphertext_producer: Option, + plaintext_consumer: Option, + plaintext_producer: Option, +} + +impl TlsStream { + fn new(conn: T, error_tx: oneshot::Sender) -> Self { + Self { + conn, + error_tx: Some(error_tx), + close_notify: false, + read_tls: None, + plaintext_producer: None, + plaintext_consumer: None, + ciphertext_producer: None, + ciphertext_consumer: None, + } + } +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/client.wit b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit new file mode 100644 index 000000000000..2aaaf695fa54 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit @@ -0,0 +1,38 @@ +interface client { + use types.{certificate}; + + resource hello { + /// Constructs a new ClientHello message. + constructor(); + + /// Sets the server name indicator. + set-server-name: func(server-name: string) -> result; + + /// Sets the ALPN IDs advertised by the client. + set-alpn-ids: func(alpn-ids: list>); + + /// Sets a list of the symmetric cipher options supported by + /// the client, specifically the record protection algorithm + /// (including secret key length) and a hash to be used with HKDF, in + /// descending order of client preference. + /// + /// If this list is empty, the implementation must use a reasonable default. + set-cipher-suites: func(cipher-suites: list); + } + + resource handshake { + set-client-certificate: func(cert: certificate); + + get-server-certificate: func() -> option; + + /// Gets the single cipher suite selected by the server from + /// the list in ClientHello.cipher_suites. + get-cipher-suite: func() -> u16; + + /// Closing the `data` stream will trigger `close_notify`. + finish: static func(this: handshake, data: stream) -> tuple, future>; + } + + /// Initiate the client TLS handshake + connect: func(hello: hello, incoming: stream) -> tuple, future>>; +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/server.wit b/crates/wasi-tls/src/p3/wit/deps/tls/server.wit new file mode 100644 index 000000000000..6d0f23d33b13 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/server.wit @@ -0,0 +1,36 @@ +interface server { + use types.{certificate}; + + resource handshake { + set-server-certificate: func(cert: certificate); + + get-client-certificate: func() -> future>; + + /// Gets the server name indicator. + /// Returns `none` if the client did not supply a SNI. + get-server-name: func() -> option; + + /// Gets the ALPN IDs advertised by the client. + /// Returns `none` if the client did not include an ALPN extension. + get-alpn-ids: func() -> option>>; + + /// Gets a list of the symmetric cipher options supported by + /// the client, specifically the record protection algorithm + /// (including secret key length) and a hash to be used with HKDF, in + /// descending order of client preference. + get-cipher-suites: func() -> list; + + /// Selects the cipher-suite from + /// the list returned by `get-cipher-suites` + /// + /// If this is not called before `finish`, implementation + /// will select appropriate cipher suite. + set-cipher-suite: func(cipher-suite: u16); + + /// Closing the `data` stream will trigger `close_notify`. + finish: static func(this: handshake, data: stream) -> tuple, future>; + } + + /// Accept the client TLS handshake + accept: async func(incoming: stream) -> result, handshake>>; +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/types.wit b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit new file mode 100644 index 000000000000..a0bcecae7da7 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit @@ -0,0 +1,5 @@ +interface types { + resource certificate { + // TODO: define + } +} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/world.wit b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit new file mode 100644 index 000000000000..f605ce358457 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit @@ -0,0 +1,7 @@ +package wasi:tls@0.3.0-draft; + +world imports { + import client; + import server; + import types; +} diff --git a/crates/wasi-tls/src/p3/wit/world.wit b/crates/wasi-tls/src/p3/wit/world.wit new file mode 100644 index 000000000000..51e0e7e8cc58 --- /dev/null +++ b/crates/wasi-tls/src/p3/wit/world.wit @@ -0,0 +1,2 @@ +// We actually don't use this; it's just to let bindgen! find the corresponding world in wit/deps. +package wasmtime:wasi-tls; diff --git a/crates/wasi-tls/tests/main.rs b/crates/wasi-tls/tests/p2.rs similarity index 90% rename from crates/wasi-tls/tests/main.rs rename to crates/wasi-tls/tests/p2.rs index 86ddee7f888f..45d3e920c917 100644 --- a/crates/wasi-tls/tests/main.rs +++ b/crates/wasi-tls/tests/p2.rs @@ -60,9 +60,9 @@ macro_rules! assert_test_exists { }; } -test_programs_artifacts::foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_p2_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +async fn p2_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P2_TLS_SAMPLE_APPLICATION_COMPONENT).await } diff --git a/crates/wasi-tls/tests/p3.rs b/crates/wasi-tls/tests/p3.rs new file mode 100644 index 000000000000..a2b1184e51c5 --- /dev/null +++ b/crates/wasi-tls/tests/p3.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "p3")] + +use wasmtime::component::{Component, Linker, ResourceTable}; +use wasmtime::{Result, Store, format_err}; +use wasmtime_wasi::p3::bindings::Command; +use wasmtime_wasi::{WasiCtx, WasiCtxView, WasiView}; +use wasmtime_wasi_tls::p3::{DefaultWasiTlsCtx, WasiTlsCtxView, WasiTlsView}; + +struct Ctx { + table: ResourceTable, + wasi_ctx: WasiCtx, + wasi_tls_ctx: DefaultWasiTlsCtx, +} + +impl WasiView for Ctx { + fn ctx(&mut self) -> WasiCtxView<'_> { + WasiCtxView { + ctx: &mut self.wasi_ctx, + table: &mut self.table, + } + } +} + +impl WasiTlsView for Ctx { + fn tls(&mut self) -> WasiTlsCtxView<'_> { + WasiTlsCtxView { + ctx: &mut self.wasi_tls_ctx, + table: &mut self.table, + } + } +} + +async fn run_test(path: &str) -> Result<()> { + let ctx = Ctx { + table: ResourceTable::new(), + wasi_ctx: WasiCtx::builder() + .inherit_stdout() + .inherit_stderr() + .inherit_network() + .allow_ip_name_lookup(true) + .build(), + wasi_tls_ctx: DefaultWasiTlsCtx, + }; + + let engine = test_programs_artifacts::engine(|config| { + config.async_support(true); + config.wasm_component_model_async(true); + }); + let mut store = Store::new(&engine, ctx); + + let mut linker = Linker::new(&engine); + // TODO: Remove once test components are not built for `wasm32-wasip1` + wasmtime_wasi::p2::add_to_linker_async(&mut linker) + .context("failed to link `wasi:cli@0.2.x`")?; + wasmtime_wasi::p3::add_to_linker(&mut linker).context("failed to link `wasi:cli@0.3.x`")?; + wasmtime_wasi_tls::p3::add_to_linker(&mut linker)?; + + let component = Component::from_file(&engine, path)?; + let command = Command::instantiate_async(&mut store, &component, &linker) + .await + .context("failed to instantiate `wasi:cli/command`")?; + store + .run_concurrent(async move |store| command.wasi_cli_run().call_run(store).await) + .await + .context("failed to call `wasi:cli/run#run`")? + .context("guest trapped")? + .map_err(|()| format_err!("`wasi:cli/run#run` failed")) +} + +macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to assert it exists")] + use self::$name as _; + }; +} + +test_programs_artifacts::foreach_p3_tls!(assert_test_exists); + +#[tokio::test(flavor = "multi_thread")] +async fn p3_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P3_TLS_SAMPLE_APPLICATION_COMPONENT).await +} From 84b35480cb9940c2910e35a9a7ca8bacc39f7ef6 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Tue, 17 Feb 2026 12:56:37 +0100 Subject: [PATCH 2/5] implement reworked TLS interface Signed-off-by: Roman Volosatovs --- Cargo.lock | 1 + ci/vendor-wit.sh | 3 +- .../src/bin/p3_tls_sample_application.rs | 155 ++++---- crates/wasi-tls/Cargo.toml | 1 + crates/wasi-tls/src/p3/bindings.rs | 11 +- crates/wasi-tls/src/p3/host/client.rs | 359 +++++++----------- crates/wasi-tls/src/p3/host/mod.rs | 249 +++++++++--- crates/wasi-tls/src/p3/host/server.rs | 284 -------------- crates/wasi-tls/src/p3/host/types.rs | 18 +- crates/wasi-tls/src/p3/mod.rs | 77 ++-- .../wasi-tls/src/p3/wit/deps/tls/client.wit | 44 +-- .../wasi-tls/src/p3/wit/deps/tls/server.wit | 36 -- crates/wasi-tls/src/p3/wit/deps/tls/types.wit | 4 +- crates/wasi-tls/src/p3/wit/deps/tls/world.wit | 1 - crates/wasi-tls/tests/p3.rs | 13 +- 15 files changed, 470 insertions(+), 786 deletions(-) delete mode 100644 crates/wasi-tls/src/p3/host/server.rs delete mode 100644 crates/wasi-tls/src/p3/wit/deps/tls/server.wit diff --git a/Cargo.lock b/Cargo.lock index 1d39a6a67a05..3ceb063cb122 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5186,6 +5186,7 @@ dependencies = [ "bytes", "futures", "rustls", + "test-log", "test-programs-artifacts", "tokio", "tokio-rustls", diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 18be392500c4..2dd378ae7833 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -65,7 +65,8 @@ mkdir -p crates/wasi-tls/wit/deps wkg get --format wit --overwrite "wasi:io@$p2" -o "crates/wasi-tls/wit/deps/io.wit" get_github wasi-tls v0.2.0-draft+505fc98 crates/wasi-tls/wit/deps/tls -get_github wasi-tls v0.3.0-draft crates/wasi-tls/src/p3/wit/deps/tls +# TODO: Use the tag when released +#get_github wasi-tls v0.3.0-draft crates/wasi-tls/src/p3/wit/deps/tls rm -rf crates/wasi-config/wit/deps mkdir -p crates/wasi-config/wit/deps diff --git a/crates/test-programs/src/bin/p3_tls_sample_application.rs b/crates/test-programs/src/bin/p3_tls_sample_application.rs index 4ecfbe4e71a4..251c43553374 100644 --- a/crates/test-programs/src/bin/p3_tls_sample_application.rs +++ b/crates/test-programs/src/bin/p3_tls_sample_application.rs @@ -1,15 +1,10 @@ -use anyhow::{Context as _, Result, anyhow, bail}; -use core::future::{Future as _, poll_fn}; -use core::pin::pin; -use core::str; -use core::task::{Poll, ready}; +use anyhow::{Context as _, Result, anyhow}; +use core::future::Future; use futures::try_join; use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses; use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket}; -use test_programs::p3::wasi::tls; -use test_programs::p3::wasi::tls::client::Hello; +use test_programs::p3::wasi::tls::client::Connector; use test_programs::p3::wit_stream; -use wit_bindgen::StreamResult; struct Component; @@ -27,63 +22,52 @@ async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> .await .context("tcp connect failed")?; + let conn = Connector::new(); + let (sock_rx, sock_rx_fut) = sock.receive(); - let hello = Hello::new(); - hello - .set_server_name(domain) - .map_err(|()| anyhow!("failed to set SNI"))?; - let (sock_tx, conn) = tls::client::connect(hello, sock_rx); - let sock_tx_fut = sock.send(sock_tx); - - let mut conn = pin!(conn.into_future()); - let mut sock_rx_fut = pin!(sock_rx_fut.into_future()); - let mut sock_tx_fut = pin!(sock_tx_fut); - let conn = poll_fn(|cx| match conn.as_mut().poll(cx) { - Poll::Ready(Ok(conn)) => Poll::Ready(Ok(conn)), - Poll::Ready(Err(())) => Poll::Ready(Err(anyhow!("tls handshake failed"))), - Poll::Pending => match sock_tx_fut.as_mut().poll(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Err(anyhow!("Tx stream closed unexpectedly"))), - Poll::Ready(Err(err)) => { - Poll::Ready(Err(anyhow!("Tx stream closed with error: {err:?}"))) - } - Poll::Pending => match ready!(sock_rx_fut.as_mut().poll(cx)) { - Ok(_) => Poll::Ready(Err(anyhow!("Rx stream closed unexpectedly"))), - Err(err) => Poll::Ready(Err(anyhow!("Rx stream closed with error: {err:?}"))), - }, - }, - }) - .await?; + let (tls_rx, tls_rx_fut) = conn.receive(sock_rx); - let (mut req_tx, req_rx) = wit_stream::new(); - let (mut res_rx, result_fut) = tls::client::Handshake::finish(conn, req_rx); + let (mut data_tx, data_rx) = wit_stream::new(); + let (tls_tx, tls_tx_err_fut) = conn.send(data_rx); + let sock_tx_fut = sock.send(tls_tx); - let res = Vec::with_capacity(8192); try_join!( async { - let buf = req_tx.write_all(request.into()).await; - assert_eq!(buf, []); - drop(req_tx); + Connector::connect(conn, domain.into()) + .await + .map_err(|err| { + anyhow!(err.to_debug_string()).context("failed to establish connection") + }) + }, + async { + let buf = data_tx.write_all(request.into()).await; + assert!(buf.is_empty()); + drop(data_tx); Ok(()) }, async { - let (result, buf) = res_rx.read(res).await; - match result { - StreamResult::Complete(..) => { - drop(res_rx); - let res = String::from_utf8(buf)?; - if res.contains("HTTP/1.1 200 OK") { - Ok(()) - } else { - bail!("server did not respond with 200 OK: {res}") - } - } - StreamResult::Dropped => bail!("read dropped"), - StreamResult::Cancelled => bail!("read cancelled"), + let response = tls_rx.collect().await; + let response = String::from_utf8(response)?; + if response.contains("HTTP/1.1 200 OK") { + Ok(()) + } else { + Err(anyhow!("server did not respond with 200 OK: {response}")) } }, - async { result_fut.await.map_err(|()| anyhow!("TLS session failed")) }, - async { sock_rx_fut.await.context("TCP receipt failed") }, - async { sock_tx_fut.await.context("TCP transmit failed") }, + async { sock_rx_fut.await.context("failed to receive ciphertext") }, + async { sock_tx_fut.await.context("failed to send ciphertext") }, + async { + tls_rx_fut + .await + .map_err(|err| anyhow!(err.to_debug_string())) + .context("failed to receive plaintext") + }, + async { + tls_tx_err_fut + .await + .map_err(|err| anyhow!(err.to_debug_string())) + .context("failed to send plaintext") + }, )?; Ok(()) } @@ -92,32 +76,57 @@ async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> /// perform a TLS handshake using another unrelated domain. This should result /// in a handshake error. async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> { - const BAD_DOMAIN: &'static str = "wrongdomain.localhost"; + const BAD_DOMAIN: &str = "wrongdomain.localhost"; let sock = TcpSocket::create(ip.family()).unwrap(); sock.connect(IpSocketAddress::new(ip, PORT)) .await .context("tcp connect failed")?; + let conn = Connector::new(); + let (sock_rx, sock_rx_fut) = sock.receive(); - let hello = Hello::new(); - hello - .set_server_name(BAD_DOMAIN) - .map_err(|()| anyhow!("failed to set SNI"))?; - let (sock_tx, conn) = tls::client::connect(hello, sock_rx); - let sock_tx_fut = sock.send(sock_tx); + let (tls_rx, tls_rx_fut) = conn.receive(sock_rx); - try_join!( + let (_, data_rx) = wit_stream::new(); + let (tls_tx, tls_tx_err_fut) = conn.send(data_rx); + let sock_tx_fut = sock.send(tls_tx); + let res = try_join!( async { - match conn.await { - Err(()) => Ok(()), - Ok(_) => panic!("expecting server name mismatch"), - } + Connector::connect(conn, BAD_DOMAIN.into()) + .await + .expect("`connect` failed"); + Ok(()) }, - async { sock_rx_fut.await.context("TCP receipt failed") }, - async { sock_tx_fut.await.context("TCP transmit failed") }, - )?; - Ok(()) + async { + let response = tls_rx.collect().await; + assert_eq!(response, []); + Ok(()) + }, + async { + sock_rx_fut.await.expect("failed to receive ciphertext"); + Ok(()) + }, + async { + sock_tx_fut.await.expect("failed to send ciphertext"); + Ok(()) + }, + async { tls_rx_fut.await }, + async { tls_tx_err_fut.await }, + ); + match res { + Err(e) => { + let debug_string = e.to_debug_string(); + // We're expecting an error regarding certificates in some form or + // another. When we add more TLS backends this naive check will + // likely need to be revisited/expanded: + if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") { + return Ok(()); + } + Err(anyhow!(debug_string)) + } + Ok(_) => panic!("expecting server name mismatch"), + } } async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) @@ -126,7 +135,7 @@ where { // since this is testing remote endpoints to ensure system cert store works // the test uses a couple different endpoints to reduce the number of flakes - const DOMAINS: &'static [&'static str] = &[ + const DOMAINS: &[&str] = &[ "example.com", "api.github.com", "docs.wasmtime.dev", @@ -141,7 +150,7 @@ where .first() .map(|a| a.to_owned()) .ok_or_else(|| anyhow!("DNS lookup failed."))?; - test(&domain, ip).await + test(domain, ip).await })(); match result.await { diff --git a/crates/wasi-tls/Cargo.toml b/crates/wasi-tls/Cargo.toml index 2aa1f0c0b1ff..ab277ebd6c70 100644 --- a/crates/wasi-tls/Cargo.toml +++ b/crates/wasi-tls/Cargo.toml @@ -36,3 +36,4 @@ test-programs-artifacts = { workspace = true } wasmtime-wasi = { workspace = true } tokio = { workspace = true, features = ["macros"] } futures = { workspace = true } +test-log = { workspace = true } diff --git a/crates/wasi-tls/src/p3/bindings.rs b/crates/wasi-tls/src/p3/bindings.rs index 5dc68fb01be7..5958aebe47b8 100644 --- a/crates/wasi-tls/src/p3/bindings.rs +++ b/crates/wasi-tls/src/p3/bindings.rs @@ -6,16 +6,13 @@ mod generated { path: "src/p3/wit", world: "wasi:tls/imports", imports: { - "wasi:tls/client.[static]handshake.finish": trappable | tracing | store, - "wasi:tls/client.connect": trappable | tracing | store, - "wasi:tls/server.[static]handshake.finish": trappable | tracing | store, + "wasi:tls/client.[method]connector.receive": trappable | tracing | store, + "wasi:tls/client.[method]connector.send": trappable | tracing | store, default: trappable | tracing }, with: { - "wasi:tls/client.handshake": crate::p3::ClientHandshake, - "wasi:tls/client.hello": crate::p3::ClientHello, - "wasi:tls/server.handshake": crate::p3::ServerHandshake, - "wasi:tls/types.certificate": crate::p3::Certificate, + "wasi:tls/client.connector": crate::p3::Connector, + "wasi:tls/types.error": String, }, }); } diff --git a/crates/wasi-tls/src/p3/host/client.rs b/crates/wasi-tls/src/p3/host/client.rs index 440bd6dd5a67..6ce4a6ecbf14 100644 --- a/crates/wasi-tls/src/p3/host/client.rs +++ b/crates/wasi-tls/src/p3/host/client.rs @@ -1,259 +1,158 @@ use super::{ - CiphertextConsumer, CiphertextProducer, PlaintextConsumer, PlaintextProducer, ResultProducer, - mk_delete, mk_get, mk_get_mut, mk_push, + CiphertextConsumer, CiphertextProducer, Pending, PlaintextConsumer, PlaintextProducer, + mk_delete, mk_get_mut, mk_push, push_error, }; -use crate::p3::bindings::tls::client::{ - Handshake, Hello, Host, HostHandshake, HostHandshakeWithStore, HostHello, HostWithStore, -}; -use crate::p3::bindings::tls::types::Certificate; -use crate::p3::{TlsStream, TlsStreamClientArc, WasiTls, WasiTlsCtxView}; -use anyhow::{Context as _, anyhow, bail}; -use core::mem; -use core::net::{IpAddr, Ipv4Addr}; -use core::pin::{Pin, pin}; -use core::task::{Context, Poll}; -use rustls::client::ResolvesClientCert; -use rustls::pki_types::ServerName; +use crate::p3::bindings::tls::client::{Connector, Host, HostConnector, HostConnectorWithStore}; +use crate::p3::bindings::tls::types::Error; +use crate::p3::host::ResultProducer; +use crate::p3::{TlsStream, WasiTls, WasiTlsCtxView}; use std::sync::{Arc, Mutex}; use tokio::sync::oneshot; -use wasmtime::StoreContextMut; -use wasmtime::component::{Access, FutureProducer, FutureReader, Resource, StreamReader}; - -mk_push!(Hello, push_hello, "client hello"); -mk_get_mut!(Hello, get_hello_mut, "client hello"); -mk_delete!(Hello, delete_hello, "client hello"); - -mk_push!(Handshake, push_handshake, "client handshake"); -mk_get!(Handshake, get_handshake, "client handshake"); -mk_delete!(Handshake, delete_handshake, "client handshake"); - -#[derive(Default)] -enum ConnectProducer { - Pending { - stream: TlsStreamClientArc, - error_rx: oneshot::Receiver, - getter: fn(&mut T) -> WasiTlsCtxView<'_>, - }, - #[default] - Exhausted, -} - -impl FutureProducer for ConnectProducer -where - D: 'static, -{ - type Item = Result, ()>; - - fn poll_produce( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut store: StoreContextMut, - finish: bool, - ) -> Poll>> { - let this = self.get_mut(); - let Self::Pending { - stream, - mut error_rx, - getter, - } = mem::take(this) - else { - return Poll::Ready(Err(anyhow!("polled after ready"))); - }; - if let Poll::Ready(..) = pin!(&mut error_rx).poll(cx) { - return Poll::Ready(Ok(Some(Err(())))); - } - - { - let mut stream_lock = stream.lock(); - let TlsStream { conn, read_tls, .. } = stream_lock.as_deref_mut().unwrap(); - if conn.peer_certificates().is_none() || conn.negotiated_cipher_suite().is_none() { - if !finish { - *read_tls = Some(cx.waker().clone()); - } - drop(stream_lock); - *this = Self::Pending { - stream, - error_rx, - getter, - }; - if finish { - return Poll::Ready(Ok(None)); - } - return Poll::Pending; - } - }; - - let WasiTlsCtxView { table, .. } = getter(store.data_mut()); +use wasmtime::component::{Access, Accessor, FutureReader, Resource, StreamReader}; - let handshake = Handshake { stream, error_rx }; - let handshake = push_handshake(table, handshake)?; - - Poll::Ready(Ok(Some(Ok(handshake)))) - } -} - -#[derive(Debug)] -struct CertificateResolver; - -impl ResolvesClientCert for CertificateResolver { - fn resolve( - &self, - _root_hint_subjects: &[&[u8]], - _sigschemes: &[rustls::SignatureScheme], - ) -> Option> { - // TODO: implement - None - } - - fn has_certs(&self) -> bool { - false - } -} +mk_push!(Connector, push_connector, "client connector"); +mk_get_mut!(Connector, get_connector_mut, "client connector"); +mk_delete!(Connector, delete_connector, "client connector"); impl Host for WasiTlsCtxView<'_> {} -impl HostHello for WasiTlsCtxView<'_> { - fn new(&mut self) -> wasmtime::Result> { - push_hello(&mut self.table, Hello::default()) - } - - fn set_server_name( - &mut self, - hello: Resource, - server_name: String, - ) -> wasmtime::Result> { - let hello = get_hello_mut(&mut self.table, &hello)?; - let Ok(server_name) = server_name.try_into() else { - return Ok(Err(())); - }; - hello.server_name = Some(server_name); - Ok(Ok(())) - } - - fn set_alpn_ids( - &mut self, - hello: Resource, - alpn_ids: Vec>, - ) -> wasmtime::Result<()> { - let hello = get_hello_mut(&mut self.table, &hello)?; - hello.alpn_ids = Some(alpn_ids); - Ok(()) +impl HostConnector for WasiTlsCtxView<'_> { + fn new(&mut self) -> wasmtime::Result> { + push_connector(&mut self.table, Connector::default()) } - fn set_cipher_suites( - &mut self, - hello: Resource, - cipher_suites: Vec, - ) -> wasmtime::Result<()> { - let hello = get_hello_mut(&mut self.table, &hello)?; - hello.cipher_suites = cipher_suites; - Ok(()) - } - - fn drop(&mut self, hello: Resource) -> wasmtime::Result<()> { - delete_hello(&mut self.table, hello)?; + fn drop(&mut self, conn: Resource) -> wasmtime::Result<()> { + delete_connector(&mut self.table, conn)?; Ok(()) } } -impl HostWithStore for WasiTls { - fn connect( +impl HostConnectorWithStore for WasiTls { + fn send( mut store: Access, - hello: Resource, - incoming: StreamReader, - ) -> wasmtime::Result<( - StreamReader, - FutureReader, ()>>, - )> { - let Hello { - server_name, - alpn_ids, - cipher_suites, - } = delete_hello(store.get().table, hello)?; - - let roots = rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - if !cipher_suites.is_empty() { - // TODO: implement - bail!("custom cipher suites not supported yet") - } - let mut config = rustls::ClientConfig::builder() - .with_root_certificates(roots) - .with_client_cert_resolver(Arc::new(CertificateResolver)); - if let Some(alpn_ids) = alpn_ids { - config.alpn_protocols = alpn_ids; - } - let server_name = if let Some(server_name) = server_name { - server_name - } else { - config.enable_sni = false; - ServerName::IpAddress(IpAddr::V4(Ipv4Addr::UNSPECIFIED).into()) + conn: Resource, + cleartext: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>>)> + where + T: 'static, + { + let conn @ Connector { send_tx: None, .. } = get_connector_mut(store.get().table, &conn)? + else { + return Err(wasmtime::Error::msg("`send` was already called")); }; - let conn = rustls::ClientConnection::new(Arc::from(config), server_name) - .context("failed to construct rustls client connection")?; - let (error_tx, error_rx) = oneshot::channel(); - let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); - incoming.pipe(&mut store, CiphertextConsumer(Arc::clone(&stream))); + let (cons_tx, cons_rx) = oneshot::channel(); + let (prod_tx, prod_rx) = oneshot::channel(); + let (err_tx, err_rx) = oneshot::channel(); + + conn.send_tx = Some((prod_tx, cons_tx, err_tx)); + + let rx = StreamReader::new(&mut store, Pending::from(prod_rx)); + cleartext.pipe(&mut store, Pending::from(cons_rx)); let getter = store.getter(); Ok(( - StreamReader::new(&mut store, CiphertextProducer(Arc::clone(&stream))), - FutureReader::new( - &mut store, - ConnectProducer::Pending { - stream, - error_rx, - getter, - }, - ), + rx, + FutureReader::new(store, ResultProducer { rx: err_rx, getter }), )) } -} -impl HostHandshake for WasiTlsCtxView<'_> { - fn set_client_certificate( - &mut self, - _handshake: Resource, - _cert: Resource, - ) -> wasmtime::Result<()> { - todo!() - } - - fn get_server_certificate( - &mut self, - _handshake: Resource, - ) -> wasmtime::Result>> { - todo!() - } + fn receive( + mut store: Access, + conn: Resource, + ciphertext: StreamReader, + ) -> wasmtime::Result<(StreamReader, FutureReader>>)> + where + T: 'static, + { + let conn @ Connector { + receive_tx: None, .. + } = get_connector_mut(store.get().table, &conn)? + else { + return Err(wasmtime::Error::msg("`receive` was already called")); + }; - fn get_cipher_suite(&mut self, handshake: Resource) -> wasmtime::Result { - let Handshake { stream, .. } = get_handshake(&self.table, &handshake)?; - let mut stream = stream.lock(); - let TlsStream { conn, .. } = stream.as_deref_mut().unwrap(); - let cipher_suite = conn - .negotiated_cipher_suite() - .context("cipher suite not available")?; - Ok(cipher_suite.suite().get_u16()) - } + let (cons_tx, cons_rx) = oneshot::channel(); + let (prod_tx, prod_rx) = oneshot::channel(); + let (err_tx, err_rx) = oneshot::channel(); - fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { - delete_handshake(&mut self.table, handshake)?; - Ok(()) - } -} + conn.receive_tx = Some((prod_tx, cons_tx, err_tx)); -impl HostHandshakeWithStore for WasiTls { - fn finish( - mut store: Access, - handshake: Resource, - data: StreamReader, - ) -> wasmtime::Result<(StreamReader, FutureReader>)> { - let Handshake { stream, error_rx } = delete_handshake(&mut store.get().table, handshake)?; - data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); + let rx = StreamReader::new(&mut store, Pending::from(prod_rx)); + ciphertext.pipe(&mut store, Pending::from(cons_rx)); + let getter = store.getter(); Ok(( - StreamReader::new(&mut store, PlaintextProducer(stream)), - FutureReader::new(&mut store, ResultProducer(error_rx)), + rx, + FutureReader::new(store, ResultProducer { rx: err_rx, getter }), )) } + + async fn connect( + store: &Accessor, + conn: Resource, + server_name: String, + ) -> wasmtime::Result>> + where + T: 'static, + { + store.with(|mut store| { + let server_name = match server_name.try_into() { + Ok(name) => name, + Err(err) => { + let err = push_error(store.get().table, format!("{err}"))?; + return Ok(Err(err)); + } + }; + + let Connector { + receive_tx: Some((receive_prod_tx, receive_cons_tx, receive_err_tx)), + send_tx: Some((send_prod_tx, send_cons_tx, send_err_tx)), + } = delete_connector(store.get().table, conn)? + else { + let err = push_error( + store.get().table, + format!("`send` and `receive` must be called prior to calling `connect`"), + )?; + return Ok(Err(err)); + }; + + let roots = rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + + let conn = match rustls::ClientConnection::new(Arc::from(config), server_name) { + Ok(conn) => conn, + Err(err) => { + let err = push_error(store.get().table, format!("{err}"))?; + return Ok(Err(err)); + } + }; + + let stream = Arc::new(Mutex::new(TlsStream::new(conn))); + + let send_err_tx = Arc::new(Mutex::new(Some(send_err_tx))); + let _ = send_cons_tx.send(PlaintextConsumer { + stream: Arc::clone(&stream), + error_tx: Arc::clone(&send_err_tx), + }); + let _ = send_prod_tx.send(CiphertextProducer { + stream: Arc::clone(&stream), + error_tx: send_err_tx, + }); + + let receive_err_tx = Arc::new(Mutex::new(Some(receive_err_tx))); + let _ = receive_cons_tx.send(CiphertextConsumer { + stream: Arc::clone(&stream), + error_tx: Arc::clone(&receive_err_tx), + }); + let _ = receive_prod_tx.send(PlaintextProducer { + stream, + error_tx: receive_err_tx, + }); + + Ok(Ok(())) + }) + } } diff --git a/crates/wasi-tls/src/p3/host/mod.rs b/crates/wasi-tls/src/p3/host/mod.rs index 2b48c7f0b03e..32d656b227a4 100644 --- a/crates/wasi-tls/src/p3/host/mod.rs +++ b/crates/wasi-tls/src/p3/host/mod.rs @@ -1,17 +1,17 @@ -use crate::p3::{TlsStream, TlsStreamArc}; -use anyhow::Context as _; +use crate::p3::bindings::tls::client::Error; +use crate::p3::{TlsStream, TlsStreamArc, WasiTlsCtxView}; use core::ops::DerefMut; use core::pin::Pin; use core::task::{Context, Poll, Waker}; use std::io::{Read as _, Write as _}; +use std::sync::{Arc, Mutex}; use tokio::sync::oneshot; use wasmtime::StoreContextMut; use wasmtime::component::{ - Destination, FutureProducer, Source, StreamConsumer, StreamProducer, StreamResult, + Destination, FutureProducer, Resource, Source, StreamConsumer, StreamProducer, StreamResult, }; mod client; -mod server; mod types; macro_rules! mk_push { @@ -22,7 +22,7 @@ macro_rules! mk_push { table: &mut wasmtime::component::ResourceTable, value: $t, ) -> wasmtime::Result> { - use anyhow::Context as _; + use wasmtime::error::Context as _; table .push(value) @@ -39,7 +39,7 @@ macro_rules! mk_get { table: &'a wasmtime::component::ResourceTable, resource: &'a wasmtime::component::Resource<$t>, ) -> wasmtime::Result<&'a $t> { - use anyhow::Context as _; + use wasmtime::error::Context as _; table .get(resource) @@ -56,7 +56,7 @@ macro_rules! mk_get_mut { table: &'a mut wasmtime::component::ResourceTable, resource: &'a wasmtime::component::Resource<$t>, ) -> wasmtime::Result<&'a mut $t> { - use anyhow::Context as _; + use wasmtime::error::Context as _; table.get_mut(resource).context(concat!( "failed to get ", @@ -75,7 +75,7 @@ macro_rules! mk_delete { table: &mut wasmtime::component::ResourceTable, resource: wasmtime::component::Resource<$t>, ) -> wasmtime::Result<$t> { - use anyhow::Context as _; + use wasmtime::error::Context as _; table.delete(resource).context(concat!( "failed to delete ", @@ -88,7 +88,98 @@ macro_rules! mk_delete { pub(crate) use {mk_delete, mk_get, mk_get_mut, mk_push}; -struct CiphertextConsumer(TlsStreamArc); +mk_push!(Error, push_error, "error"); + +struct Pending { + inner_rx: oneshot::Receiver, + inner: Option, +} + +impl From> for Pending { + fn from(rx: oneshot::Receiver) -> Self { + Self { + inner_rx: rx, + inner: None, + } + } +} + +impl StreamProducer for Pending +where + T: StreamProducer + Unpin, +{ + type Item = >::Item; + type Buffer = >::Buffer; + + fn poll_produce<'a>( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut<'a, D>, + dst: Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> Poll> { + if let Some(ref mut inner) = self.inner { + return Pin::new(inner).poll_produce(cx, store, dst, finish); + } + match Pin::new(&mut self.inner_rx).poll(cx) { + Poll::Ready(Ok(inner)) => { + self.inner = Some(inner); + return self.poll_produce(cx, store, dst, finish); + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => Poll::Ready(Ok(StreamResult::Cancelled)), + Poll::Pending => Poll::Pending, + } + } +} + +impl StreamConsumer for Pending +where + T: StreamConsumer + Unpin, +{ + type Item = >::Item; + + fn poll_consume( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + store: StoreContextMut, + src: Source, + finish: bool, + ) -> Poll> { + if let Some(ref mut inner) = self.inner { + return Pin::new(inner).poll_consume(cx, store, src, finish); + } + match Pin::new(&mut self.inner_rx).poll(cx) { + Poll::Ready(Ok(inner)) => { + self.inner = Some(inner); + return self.poll_consume(cx, store, src, finish); + } + Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending if finish => Poll::Ready(Ok(StreamResult::Cancelled)), + Poll::Pending => Poll::Pending, + } + } +} + +pub struct CiphertextConsumer { + stream: TlsStreamArc, + error_tx: Arc>>>, +} + +impl Drop for CiphertextConsumer { + fn drop(&mut self) { + let mut stream = self.stream.lock(); + let TlsStream { + ciphertext_consumer_dropped, + plaintext_producer, + ciphertext_producer, + .. + } = stream.as_deref_mut().unwrap(); + *ciphertext_consumer_dropped = true; + plaintext_producer.take().map(Waker::wake); + ciphertext_producer.take().map(Waker::wake); + } +} impl StreamConsumer for CiphertextConsumer where @@ -103,20 +194,20 @@ where src: Source, finish: bool, ) -> Poll> { - let mut stream = self.0.lock(); + let mut error_tx = self.error_tx.lock().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut stream = self.stream.lock(); let TlsStream { conn, - error_tx, - read_tls, ciphertext_consumer, ciphertext_producer, plaintext_consumer, plaintext_producer, .. } = stream.as_deref_mut().unwrap(); - if error_tx.is_none() { - return Poll::Ready(Ok(StreamResult::Dropped)); - } if !conn.wants_read() { if finish { @@ -132,7 +223,6 @@ where } let n = conn.read_tls(&mut src)?; debug_assert_ne!(n, 0); - read_tls.take().map(Waker::wake); let state = match conn.process_new_packets() { Ok(state) => state, @@ -154,16 +244,17 @@ where } if state.peer_has_closed() { - // even if there are no bytes to read, the producer may be pending - plaintext_producer.take().map(Waker::wake); - return Poll::Ready(Ok(StreamResult::Dropped)); + Poll::Ready(Ok(StreamResult::Dropped)) + } else { + Poll::Ready(Ok(StreamResult::Completed)) } - - Poll::Ready(Ok(StreamResult::Completed)) } } -struct PlaintextProducer(TlsStreamArc); +pub struct PlaintextProducer { + stream: TlsStreamArc, + error_tx: Arc>>>, +} impl StreamProducer for PlaintextProducer where @@ -179,21 +270,35 @@ where dst: Destination<'a, Self::Item, Self::Buffer>, finish: bool, ) -> Poll> { - let mut stream = self.0.lock(); + let mut error_tx = self.error_tx.lock().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut stream = self.stream.lock(); let TlsStream { conn, - error_tx, + ciphertext_consumer_dropped, ciphertext_consumer, + ciphertext_producer, + plaintext_consumer, plaintext_producer, .. } = stream.as_deref_mut().unwrap(); - if error_tx.is_none() { - return Poll::Ready(Ok(StreamResult::Dropped)); - } - let state = conn.process_new_packets().context("unhandled TLS error")?; + let state = match conn.process_new_packets() { + Ok(state) => state, + Err(err) => { + _ = error_tx.take().unwrap().send(err); + ciphertext_consumer.take().map(Waker::wake); + ciphertext_producer.take().map(Waker::wake); + plaintext_consumer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + }; + if state.plaintext_bytes_to_read() == 0 { - if state.peer_has_closed() { + if state.peer_has_closed() || *ciphertext_consumer_dropped { return Poll::Ready(Ok(StreamResult::Dropped)); } else if finish { return Poll::Ready(Ok(StreamResult::Cancelled)); @@ -217,24 +322,26 @@ where } } -struct PlaintextConsumer(TlsStreamArc) +pub struct PlaintextConsumer where - T: DerefMut> + Send + 'static; + T: DerefMut> + Send + 'static, +{ + stream: TlsStreamArc, + error_tx: Arc>>>, +} impl Drop for PlaintextConsumer where T: DerefMut> + Send + 'static, { fn drop(&mut self) { - let mut stream = self.0.lock(); + let mut stream = self.stream.lock(); let TlsStream { - conn, - close_notify, + plaintext_consumer_dropped, ciphertext_producer, .. } = stream.as_deref_mut().unwrap(); - conn.send_close_notify(); - *close_notify = true; + *plaintext_consumer_dropped = true; ciphertext_producer.take().map(Waker::wake); } } @@ -253,17 +360,18 @@ where src: Source, finish: bool, ) -> Poll> { - let mut stream = self.0.lock(); + let error_tx = self.error_tx.lock().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut stream = self.stream.lock(); let TlsStream { conn, - error_tx, ciphertext_producer, plaintext_consumer, .. } = stream.as_deref_mut().unwrap(); - if error_tx.is_none() { - return Poll::Ready(Ok(StreamResult::Dropped)); - } let mut src = src.as_direct(store); if src.remaining().is_empty() { @@ -288,7 +396,10 @@ where } } -struct CiphertextProducer(TlsStreamArc); +pub struct CiphertextProducer { + stream: TlsStreamArc, + error_tx: Arc>>>, +} impl StreamProducer for CiphertextProducer where @@ -304,22 +415,24 @@ where dst: Destination<'a, Self::Item, Self::Buffer>, finish: bool, ) -> Poll> { - let mut stream = self.0.lock(); + let mut error_tx = self.error_tx.lock().unwrap(); + if error_tx.is_none() { + return Poll::Ready(Ok(StreamResult::Dropped)); + } + + let mut stream = self.stream.lock(); let TlsStream { conn, - error_tx, - close_notify, + plaintext_consumer_dropped, + ciphertext_consumer_dropped, ciphertext_consumer, ciphertext_producer, plaintext_consumer, - .. + plaintext_producer, } = stream.as_deref_mut().unwrap(); - if error_tx.is_none() { - return Poll::Ready(Ok(StreamResult::Dropped)); - } if !conn.wants_write() { - if *close_notify { + if *plaintext_consumer_dropped && *ciphertext_consumer_dropped { return Poll::Ready(Ok(StreamResult::Dropped)); } else if finish { return Poll::Ready(Ok(StreamResult::Cancelled)); @@ -329,7 +442,17 @@ where return Poll::Pending; } - let state = conn.process_new_packets().context("unhandled TLS error")?; + let state = match conn.process_new_packets() { + Ok(state) => state, + Err(err) => { + _ = error_tx.take().unwrap().send(err); + ciphertext_consumer.take().map(Waker::wake); + plaintext_consumer.take().map(Waker::wake); + plaintext_producer.take().map(Waker::wake); + return Poll::Ready(Ok(StreamResult::Dropped)); + } + }; + let mut dst = dst.as_direct(store, state.tls_bytes_to_write()); if dst.remaining().is_empty() { return Poll::Ready(Ok(StreamResult::Completed)); @@ -343,19 +466,29 @@ where } } -struct ResultProducer(oneshot::Receiver); +pub struct ResultProducer { + rx: oneshot::Receiver, + getter: for<'a> fn(&'a mut T) -> WasiTlsCtxView<'a>, +} -impl FutureProducer for ResultProducer { - type Item = Result<(), ()>; +impl FutureProducer for ResultProducer +where + D: 'static, +{ + type Item = Result<(), Resource>; fn poll_produce( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - _store: StoreContextMut, + mut store: StoreContextMut, finish: bool, - ) -> Poll>> { - match Pin::new(&mut self.0).poll(cx) { - Poll::Ready(Ok(_err)) => Poll::Ready(Ok(Some(Err(())))), + ) -> Poll>> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(Ok(err)) => { + let WasiTlsCtxView { table, .. } = (self.getter)(store.data_mut()); + let err = push_error(table, format!("{err}"))?; + Poll::Ready(Ok(Some(Err(err)))) + } Poll::Ready(Err(..)) => Poll::Ready(Ok(Some(Ok(())))), Poll::Pending if finish => Poll::Ready(Ok(None)), Poll::Pending => Poll::Pending, diff --git a/crates/wasi-tls/src/p3/host/server.rs b/crates/wasi-tls/src/p3/host/server.rs deleted file mode 100644 index 94ec6481a477..000000000000 --- a/crates/wasi-tls/src/p3/host/server.rs +++ /dev/null @@ -1,284 +0,0 @@ -#![expect(unused, reason = "WIP")] - -use super::{PlaintextConsumer, PlaintextProducer, ResultProducer, mk_delete, mk_get, mk_push}; -use crate::p3::bindings::tls::server::{ - Handshake, Host, HostHandshake, HostHandshakeWithStore, HostWithStore, -}; -use crate::p3::bindings::tls::types::Certificate; -use crate::p3::{TlsStream, TlsStreamServerArc, WasiTls, WasiTlsCtxView}; -use anyhow::{Context as _, anyhow}; -use core::mem; -use core::pin::Pin; -use core::task::{Context, Poll}; -use rustls::server::ResolvesServerCert; -use std::sync::{Arc, Mutex}; -use tokio::sync::oneshot; -use wasmtime::StoreContextMut; -use wasmtime::component::{ - Access, Accessor, Destination, FutureReader, Resource, Source, StreamConsumer, StreamProducer, - StreamReader, StreamResult, -}; - -mk_delete!(Handshake, delete_handshake, "server handshake"); -mk_get!(Handshake, get_handshake, "server handshake"); -mk_push!(Handshake, push_handshake, "server handshake"); - -enum CiphertextConsumer { - Pending { - acceptor: rustls::server::Acceptor, - tx: oneshot::Sender< - Result< - ( - rustls::server::Accepted, - oneshot::Sender, - ), - rustls::Error, - >, - >, - }, - Accepted(oneshot::Receiver), - Active(super::CiphertextConsumer), - Corrupted, -} - -impl StreamConsumer for CiphertextConsumer { - type Item = u8; - - fn poll_consume( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - store: StoreContextMut, - src: Source, - finish: bool, - ) -> Poll> { - let this = self.get_mut(); - match mem::replace(this, Self::Corrupted) { - Self::Pending { mut acceptor, tx } => { - let mut src = src.as_direct(store); - if src.remaining().is_empty() { - return Poll::Ready(Ok(StreamResult::Completed)); - } - acceptor.read_tls(&mut src)?; - match acceptor.accept() { - Ok(None) => { - *this = Self::Pending { acceptor, tx }; - Poll::Ready(Ok(StreamResult::Completed)) - } - Ok(Some(accepted)) => { - let (stream_tx, stream_rx) = oneshot::channel(); - _ = tx.send(Ok((accepted, stream_tx))); - *this = Self::Accepted(stream_rx); - Poll::Ready(Ok(StreamResult::Completed)) - } - Err(err) => { - _ = tx.send(Err(err)); - Poll::Ready(Ok(StreamResult::Dropped)) - } - } - } - Self::Accepted(mut rx) => match Pin::new(&mut rx).poll(cx) { - Poll::Ready(Ok(stream)) => { - *this = Self::Active(super::CiphertextConsumer(stream)); - Poll::Ready(Ok(StreamResult::Completed)) - } - Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), - Poll::Pending if finish => { - *this = Self::Accepted(rx); - Poll::Ready(Ok(StreamResult::Cancelled)) - } - Poll::Pending => { - *this = Self::Accepted(rx); - Poll::Ready(Ok(StreamResult::Cancelled)) - } - }, - Self::Active(ref mut conn) => Pin::new(conn).poll_consume(cx, store, src, finish), - Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream consumer state"))), - } - } -} - -enum CiphertextProducer { - Pending(oneshot::Receiver), - Active(super::CiphertextProducer), - Corrupted, -} - -impl StreamProducer for CiphertextProducer { - type Item = u8; - type Buffer = Option; // unused - - fn poll_produce<'a>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - store: StoreContextMut<'a, D>, - dst: Destination<'a, Self::Item, Self::Buffer>, - finish: bool, - ) -> Poll> { - let this = self.get_mut(); - match mem::replace(this, Self::Corrupted) { - Self::Pending(mut rx) => match Pin::new(&mut rx).poll(cx) { - Poll::Ready(Ok(stream)) => { - *this = Self::Active(super::CiphertextProducer(stream)); - Poll::Ready(Ok(StreamResult::Completed)) - } - Poll::Ready(Err(..)) => Poll::Ready(Ok(StreamResult::Dropped)), - Poll::Pending if finish => { - *this = Self::Pending(rx); - Poll::Ready(Ok(StreamResult::Cancelled)) - } - Poll::Pending => { - *this = Self::Pending(rx); - Poll::Ready(Ok(StreamResult::Cancelled)) - } - }, - Self::Active(ref mut conn) => Pin::new(conn).poll_produce(cx, store, dst, finish), - Self::Corrupted => Poll::Ready(Err(anyhow!("corrupted stream producer state"))), - } - } -} - -#[derive(Debug)] -struct CertificateResolver; - -impl ResolvesServerCert for CertificateResolver { - fn resolve( - &self, - hello: rustls::server::ClientHello, - ) -> Option> { - // TODO: Implement - None - } -} - -impl Host for WasiTlsCtxView<'_> {} - -impl HostWithStore for WasiTls { - async fn accept( - store: &Accessor, - incoming: StreamReader, - ) -> wasmtime::Result, Resource), ()>> { - let (accept_tx, accept_rx) = oneshot::channel(); - store.with(|store| { - incoming.pipe( - store, - CiphertextConsumer::Pending { - acceptor: rustls::server::Acceptor::default(), - tx: accept_tx, - }, - ); - }); - let (accepted, consumer_tx) = match accept_rx - .await - .context("oneshot sender dropped unexpectedly")? - { - Ok((accepted, consumer_tx)) => (accepted, consumer_tx), - Err(_err) => return Ok(Err(())), - }; - let (producer_tx, producer_rx) = oneshot::channel(); - store.with(|mut store| { - let handshake = push_handshake( - store.get().table, - Handshake { - accepted, - consumer_tx, - producer_tx, - }, - )?; - Ok(Ok(( - StreamReader::new(&mut store, CiphertextProducer::Pending(producer_rx)), - handshake, - ))) - }) - } -} - -impl HostHandshake for WasiTlsCtxView<'_> { - fn set_server_certificate( - &mut self, - handshake: Resource, - cert: Resource, - ) -> wasmtime::Result<()> { - todo!() - } - - fn get_client_certificate( - &mut self, - handshake: Resource, - ) -> wasmtime::Result, ()>>> { - todo!() - } - - fn get_server_name( - &mut self, - handshake: Resource, - ) -> wasmtime::Result> { - let handshake = get_handshake(&self.table, &handshake)?; - let hello = handshake.accepted.client_hello(); - let server_name = hello.server_name().map(Into::into); - Ok(server_name) - } - - fn get_alpn_ids( - &mut self, - handshake: Resource, - ) -> wasmtime::Result>>> { - let handshake = get_handshake(&self.table, &handshake)?; - let hello = handshake.accepted.client_hello(); - let alpn = hello.alpn().map(|alpn| alpn.map(Into::into).collect()); - Ok(alpn) - } - - fn get_cipher_suites(&mut self, handshake: Resource) -> wasmtime::Result> { - let handshake = get_handshake(&self.table, &handshake)?; - let hello = handshake.accepted.client_hello(); - let cipher_suites = hello - .cipher_suites() - .into_iter() - .map(rustls::CipherSuite::get_u16) - .collect(); - Ok(cipher_suites) - } - - fn set_cipher_suite( - &mut self, - handshake: Resource, - cipher_suite: u16, - ) -> wasmtime::Result<()> { - todo!() - } - - fn drop(&mut self, handshake: Resource) -> wasmtime::Result<()> { - delete_handshake(&mut self.table, handshake)?; - Ok(()) - } -} - -impl HostHandshakeWithStore for WasiTls { - fn finish( - mut store: Access, - handshake: Resource, - data: StreamReader, - ) -> wasmtime::Result<(StreamReader, FutureReader>)> { - let Handshake { - accepted, - consumer_tx, - producer_tx, - } = delete_handshake(&mut store.get().table, handshake)?; - // TODO: configure - let config = rustls::ServerConfig::builder() - .with_no_client_auth() - .with_cert_resolver(Arc::new(CertificateResolver)); - let conn = accepted - .into_connection(Arc::from(config)) - .context("failed to construct rustls server connection")?; - let (error_tx, error_rx) = oneshot::channel(); - let stream = Arc::new(Mutex::new(TlsStream::new(conn, error_tx))); - data.pipe(&mut store, PlaintextConsumer(Arc::clone(&stream))); - _ = consumer_tx.send(Arc::clone(&stream)); - _ = producer_tx.send(Arc::clone(&stream)); - Ok(( - StreamReader::new(&mut store, PlaintextProducer(stream)), - FutureReader::new(&mut store, ResultProducer(error_rx)), - )) - } -} diff --git a/crates/wasi-tls/src/p3/host/types.rs b/crates/wasi-tls/src/p3/host/types.rs index 2c71ef0fbb5e..98a11b69b41c 100644 --- a/crates/wasi-tls/src/p3/host/types.rs +++ b/crates/wasi-tls/src/p3/host/types.rs @@ -1,15 +1,21 @@ -use super::mk_delete; +use super::{mk_delete, mk_get}; use crate::p3::WasiTlsCtxView; -use crate::p3::bindings::tls::types::{Certificate, Host, HostCertificate}; +use crate::p3::bindings::tls::types::{Error, Host, HostError}; use wasmtime::component::Resource; -mk_delete!(Certificate, delete_certificate, "certificate"); +mk_get!(Error, get_error, "error"); +mk_delete!(Error, delete_error, "error"); impl Host for WasiTlsCtxView<'_> {} -impl HostCertificate for WasiTlsCtxView<'_> { - fn drop(&mut self, cert: Resource) -> wasmtime::Result<()> { - delete_certificate(&mut self.table, cert)?; +impl HostError for WasiTlsCtxView<'_> { + fn to_debug_string(&mut self, err: Resource) -> wasmtime::Result { + let err = get_error(self.table, &err)?; + Ok(err.clone()) + } + + fn drop(&mut self, err: Resource) -> wasmtime::Result<()> { + delete_error(&mut self.table, err)?; Ok(()) } } diff --git a/crates/wasi-tls/src/p3/mod.rs b/crates/wasi-tls/src/p3/mod.rs index e3be4924a4b4..fbab96748f45 100644 --- a/crates/wasi-tls/src/p3/mod.rs +++ b/crates/wasi-tls/src/p3/mod.rs @@ -11,11 +11,12 @@ pub mod bindings; mod host; +use crate::p3::host::{ + CiphertextConsumer, CiphertextProducer, PlaintextConsumer, PlaintextProducer, +}; +use bindings::tls::{client, types}; use core::task::Waker; use std::sync::{Arc, Mutex}; - -use bindings::tls::{client, server, types}; -use rustls::pki_types::ServerName; use tokio::sync::oneshot; use wasmtime::component::{HasData, Linker, ResourceTable}; @@ -65,7 +66,6 @@ pub trait WasiTlsView: Send { /// /// fn main() -> Result<()> { /// let mut config = Config::new(); -/// config.async_support(true); /// config.wasm_component_model_async(true); /// let engine = Engine::new(&config)?; /// @@ -103,61 +103,33 @@ where T: WasiTlsView + 'static, { client::add_to_linker::<_, WasiTls>(linker, T::tls)?; - server::add_to_linker::<_, WasiTls>(linker, T::tls)?; types::add_to_linker::<_, WasiTls>(linker, T::tls)?; Ok(()) } -/// Client hello -#[derive(Clone, Default, Eq, PartialEq, Hash)] -pub struct ClientHello { - /// Server name indicator. - pub server_name: Option>, - /// ALPN IDs - pub alpn_ids: Option>>, - /// Cipher suites - pub cipher_suites: Vec, -} - -/// Server hello -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct ServerHello { - /// Cipher suite - pub cipher_suite: u16, -} - -impl ServerHello { - /// Constructs a new server hello message - pub fn new(cipher_suite: u16) -> Self { - Self { cipher_suite } - } +/// TLS client connector state. +#[derive(Default)] +pub struct Connector { + pub(crate) receive_tx: Option<( + oneshot::Sender>, + oneshot::Sender>, + oneshot::Sender, + )>, + pub(crate) send_tx: Option<( + oneshot::Sender>, + oneshot::Sender< + PlaintextConsumer, + >, + oneshot::Sender, + )>, } type TlsStreamArc = Arc>>; -type TlsStreamClientArc = TlsStreamArc; -type TlsStreamServerArc = TlsStreamArc; - -/// Client handshake -pub struct ClientHandshake { - stream: TlsStreamClientArc, - error_rx: oneshot::Receiver, -} - -/// Server handshake -pub struct ServerHandshake { - accepted: rustls::server::Accepted, - consumer_tx: oneshot::Sender, - producer_tx: oneshot::Sender, -} - -/// Certificate -pub struct Certificate; struct TlsStream { conn: T, - error_tx: Option>, - close_notify: bool, - read_tls: Option, + plaintext_consumer_dropped: bool, + ciphertext_consumer_dropped: bool, ciphertext_consumer: Option, ciphertext_producer: Option, plaintext_consumer: Option, @@ -165,12 +137,11 @@ struct TlsStream { } impl TlsStream { - fn new(conn: T, error_tx: oneshot::Sender) -> Self { + fn new(conn: T) -> Self { Self { conn, - error_tx: Some(error_tx), - close_notify: false, - read_tls: None, + plaintext_consumer_dropped: false, + ciphertext_consumer_dropped: false, plaintext_producer: None, plaintext_consumer: None, ciphertext_producer: None, diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/client.wit b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit index 2aaaf695fa54..8f282dbe8e36 100644 --- a/crates/wasi-tls/src/p3/wit/deps/tls/client.wit +++ b/crates/wasi-tls/src/p3/wit/deps/tls/client.wit @@ -1,38 +1,22 @@ interface client { - use types.{certificate}; + use types.{error}; - resource hello { - /// Constructs a new ClientHello message. + resource connector { constructor(); - /// Sets the server name indicator. - set-server-name: func(server-name: string) -> result; + /// Set up the encryption stream transform. + /// This takes an unprotected `cleartext` application data stream and + /// returns an encrypted data stream, ready to be sent out over the network. + /// Closing the `cleartext` stream will cause a `close_notify` packet to be emitted on the returned output stream. + send: func(cleartext: stream) -> tuple, future>>; - /// Sets the ALPN IDs advertised by the client. - set-alpn-ids: func(alpn-ids: list>); + /// Set up the decryption stream transform. + /// This takes an encrypted data stream, as received via e.g. the network, + /// and returns a decrypted application data stream. + receive: func(ciphertext: stream) -> tuple, future>>; - /// Sets a list of the symmetric cipher options supported by - /// the client, specifically the record protection algorithm - /// (including secret key length) and a hash to be used with HKDF, in - /// descending order of client preference. - /// - /// If this list is empty, the implementation must use a reasonable default. - set-cipher-suites: func(cipher-suites: list); + /// Perform the handshake. + /// The `send` & `receive` streams must be set up before calling this method. + connect: static async func(this: connector, server-name: string) -> result<_, error>; } - - resource handshake { - set-client-certificate: func(cert: certificate); - - get-server-certificate: func() -> option; - - /// Gets the single cipher suite selected by the server from - /// the list in ClientHello.cipher_suites. - get-cipher-suite: func() -> u16; - - /// Closing the `data` stream will trigger `close_notify`. - finish: static func(this: handshake, data: stream) -> tuple, future>; - } - - /// Initiate the client TLS handshake - connect: func(hello: hello, incoming: stream) -> tuple, future>>; } diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/server.wit b/crates/wasi-tls/src/p3/wit/deps/tls/server.wit deleted file mode 100644 index 6d0f23d33b13..000000000000 --- a/crates/wasi-tls/src/p3/wit/deps/tls/server.wit +++ /dev/null @@ -1,36 +0,0 @@ -interface server { - use types.{certificate}; - - resource handshake { - set-server-certificate: func(cert: certificate); - - get-client-certificate: func() -> future>; - - /// Gets the server name indicator. - /// Returns `none` if the client did not supply a SNI. - get-server-name: func() -> option; - - /// Gets the ALPN IDs advertised by the client. - /// Returns `none` if the client did not include an ALPN extension. - get-alpn-ids: func() -> option>>; - - /// Gets a list of the symmetric cipher options supported by - /// the client, specifically the record protection algorithm - /// (including secret key length) and a hash to be used with HKDF, in - /// descending order of client preference. - get-cipher-suites: func() -> list; - - /// Selects the cipher-suite from - /// the list returned by `get-cipher-suites` - /// - /// If this is not called before `finish`, implementation - /// will select appropriate cipher suite. - set-cipher-suite: func(cipher-suite: u16); - - /// Closing the `data` stream will trigger `close_notify`. - finish: static func(this: handshake, data: stream) -> tuple, future>; - } - - /// Accept the client TLS handshake - accept: async func(incoming: stream) -> result, handshake>>; -} diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/types.wit b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit index a0bcecae7da7..fc6c4b103e16 100644 --- a/crates/wasi-tls/src/p3/wit/deps/tls/types.wit +++ b/crates/wasi-tls/src/p3/wit/deps/tls/types.wit @@ -1,5 +1,5 @@ interface types { - resource certificate { - // TODO: define + resource error { + to-debug-string: func() -> string; } } diff --git a/crates/wasi-tls/src/p3/wit/deps/tls/world.wit b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit index f605ce358457..599d49688ad2 100644 --- a/crates/wasi-tls/src/p3/wit/deps/tls/world.wit +++ b/crates/wasi-tls/src/p3/wit/deps/tls/world.wit @@ -2,6 +2,5 @@ package wasi:tls@0.3.0-draft; world imports { import client; - import server; import types; } diff --git a/crates/wasi-tls/tests/p3.rs b/crates/wasi-tls/tests/p3.rs index a2b1184e51c5..e68c07df238c 100644 --- a/crates/wasi-tls/tests/p3.rs +++ b/crates/wasi-tls/tests/p3.rs @@ -1,6 +1,7 @@ #![cfg(feature = "p3")] use wasmtime::component::{Component, Linker, ResourceTable}; +use wasmtime::error::Context as _; use wasmtime::{Result, Store, format_err}; use wasmtime_wasi::p3::bindings::Command; use wasmtime_wasi::{WasiCtx, WasiCtxView, WasiView}; @@ -43,7 +44,6 @@ async fn run_test(path: &str) -> Result<()> { }; let engine = test_programs_artifacts::engine(|config| { - config.async_support(true); config.wasm_component_model_async(true); }); let mut store = Store::new(&engine, ctx); @@ -59,12 +59,15 @@ async fn run_test(path: &str) -> Result<()> { let command = Command::instantiate_async(&mut store, &component, &linker) .await .context("failed to instantiate `wasi:cli/command`")?; - store + let (res, task) = store .run_concurrent(async move |store| command.wasi_cli_run().call_run(store).await) .await .context("failed to call `wasi:cli/run#run`")? - .context("guest trapped")? - .map_err(|()| format_err!("`wasi:cli/run#run` failed")) + .context("guest trapped")?; + res.map_err(|()| format_err!("`wasi:cli/run#run` failed"))?; + store + .run_concurrent(async move |store| task.block(store).await) + .await } macro_rules! assert_test_exists { @@ -76,7 +79,7 @@ macro_rules! assert_test_exists { test_programs_artifacts::foreach_p3_tls!(assert_test_exists); -#[tokio::test(flavor = "multi_thread")] +#[test_log::test(tokio::test(flavor = "multi_thread"))] async fn p3_tls_sample_application() -> Result<()> { run_test(test_programs_artifacts::P3_TLS_SAMPLE_APPLICATION_COMPONENT).await } From 1461d3db442cb64717e55423d5109c46488d6cee Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Tue, 17 Feb 2026 13:39:31 +0100 Subject: [PATCH 3/5] adapt openssl tests to new test names Signed-off-by: Roman Volosatovs --- crates/wasi-tls-openssl/tests/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/wasi-tls-openssl/tests/main.rs b/crates/wasi-tls-openssl/tests/main.rs index d09c0939aacb..a1683d549017 100644 --- a/crates/wasi-tls-openssl/tests/main.rs +++ b/crates/wasi-tls-openssl/tests/main.rs @@ -60,9 +60,9 @@ macro_rules! assert_test_exists { }; } -test_programs_artifacts::foreach_tls!(assert_test_exists); +test_programs_artifacts::foreach_p2_tls!(assert_test_exists); #[tokio::test(flavor = "multi_thread")] -async fn tls_sample_application() -> Result<()> { - run_test(test_programs_artifacts::TLS_SAMPLE_APPLICATION_COMPONENT).await +async fn p2_tls_sample_application() -> Result<()> { + run_test(test_programs_artifacts::P2_TLS_SAMPLE_APPLICATION_COMPONENT).await } From 575ee54db5c5e6caebd84dfe5af042e01c88d832 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Tue, 17 Feb 2026 15:46:55 +0100 Subject: [PATCH 4/5] add p3 TLS support to CLI Signed-off-by: Roman Volosatovs --- Cargo.toml | 1 + src/commands/run.rs | 23 ++++++++++++++++++++--- src/common.rs | 7 ++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 30c846c9a860..ca6366abd569 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -520,6 +520,7 @@ component-model-async = [ "component-model", "wasmtime-wasi?/p3", "wasmtime-wasi-http?/p3", + "wasmtime-wasi-tls?/p3", "dep:futures", ] rr = ["wasmtime/rr", "component-model", "wasmtime-cli-flags/rr", "run"] diff --git a/src/commands/run.rs b/src/commands/run.rs index d61c7fe62521..9af831e18a86 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -1092,6 +1092,10 @@ impl RunCommand { ctx.ctx().table, ) })?; + #[cfg(feature = "component-model-async")] + if self.run.common.wasi.p3.unwrap_or(crate::common::P3_DEFAULT) { + wasmtime_wasi_tls::p3::add_to_linker(linker)?; + } let ctx = wasmtime_wasi_tls::WasiTlsCtxBuilder::new().build(); store.data_mut().wasi_tls = Some(Arc::new(ctx)); @@ -1225,8 +1229,11 @@ pub struct Host { wasi_http_outgoing_body_buffer_chunks: Option, #[cfg(feature = "wasi-http")] wasi_http_outgoing_body_chunk_size: Option, - #[cfg(all(feature = "wasi-http", feature = "component-model-async"))] - p3_http: crate::common::DefaultP3Ctx, + #[cfg(all( + any(feature = "wasi-http", feature = "wasi-tls"), + feature = "component-model-async" + ))] + p3_ctx: crate::common::DefaultP3Ctx, limits: StoreLimits, #[cfg(feature = "profiling")] guest_profiler: Option>, @@ -1286,7 +1293,17 @@ impl wasmtime_wasi_http::p3::WasiHttpView for Host { fn http(&mut self) -> wasmtime_wasi_http::p3::WasiHttpCtxView<'_> { wasmtime_wasi_http::p3::WasiHttpCtxView { table: WasiView::ctx(unwrap_singlethread_context(&mut self.wasip1_ctx)).table, - ctx: &mut self.p3_http, + ctx: &mut self.p3_ctx, + } + } +} + +#[cfg(all(feature = "wasi-tls", feature = "component-model-async"))] +impl wasmtime_wasi_tls::p3::WasiTlsView for Host { + fn tls(&mut self) -> wasmtime_wasi_tls::p3::WasiTlsCtxView<'_> { + wasmtime_wasi_tls::p3::WasiTlsCtxView { + table: WasiView::ctx(unwrap_singlethread_context(&mut self.wasip1_ctx)).table, + ctx: &mut self.p3_ctx, } } } diff --git a/src/common.rs b/src/common.rs index 6f51c0f3e6d2..5448d5935198 100644 --- a/src/common.rs +++ b/src/common.rs @@ -433,7 +433,12 @@ impl Profile { } #[derive(Default, Clone)] -#[cfg(all(feature = "wasi-http", feature = "component-model-async"))] +#[cfg(all( + any(feature = "wasi-http", feature = "wasi-tls"), + feature = "component-model-async" +))] pub struct DefaultP3Ctx; #[cfg(all(feature = "wasi-http", feature = "component-model-async"))] impl wasmtime_wasi_http::p3::WasiHttpCtx for DefaultP3Ctx {} +#[cfg(all(feature = "wasi-tls", feature = "component-model-async"))] +impl wasmtime_wasi_tls::p3::WasiTlsCtx for DefaultP3Ctx {} From 19db6095b3fd230a5f328005c97886a71e0255b8 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Thu, 19 Feb 2026 17:13:48 +0100 Subject: [PATCH 5/5] fetch `wasi-tls` p3 WIT Signed-off-by: Roman Volosatovs --- ci/vendor-wit.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 2dd378ae7833..1b3e6c7f99e5 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -16,21 +16,22 @@ get_github() { local repo=$1 local tag=$2 local path=$3 + local prefix=${4:-wit} rm -rf "$path" mkdir -p "$path" - cached_extracted_dir="$cache_dir/$repo-$tag" + cached_extracted_dir="$cache_dir/$prefix/$repo-$tag" if [[ ! -d $cached_extracted_dir ]]; then mkdir -p $cached_extracted_dir curl --retry 5 --retry-all-errors -sLO https://github.com/WebAssembly/$repo/archive/$tag.tar.gz tar xzf $tag.tar.gz --strip-components=1 -C $cached_extracted_dir rm $tag.tar.gz - rm -rf $cached_extracted_dir/wit/deps* + rm -rf $cached_extracted_dir/${prefix}/deps* fi - cp -r $cached_extracted_dir/wit/* $path + cp -r $cached_extracted_dir/${prefix}/* $path } p2=0.2.6 @@ -65,8 +66,7 @@ mkdir -p crates/wasi-tls/wit/deps wkg get --format wit --overwrite "wasi:io@$p2" -o "crates/wasi-tls/wit/deps/io.wit" get_github wasi-tls v0.2.0-draft+505fc98 crates/wasi-tls/wit/deps/tls -# TODO: Use the tag when released -#get_github wasi-tls v0.3.0-draft crates/wasi-tls/src/p3/wit/deps/tls +get_github wasi-tls v0.2.0-draft+6781ae2 crates/wasi-tls/src/p3/wit/deps/tls wit-0.3.0-draft rm -rf crates/wasi-config/wit/deps mkdir -p crates/wasi-config/wit/deps