From be2789904a33f93ca052824ee870f197f6a4334c Mon Sep 17 00:00:00 2001 From: shellrow Date: Tue, 13 Jan 2026 00:16:54 +0900 Subject: [PATCH 01/16] Add mux sub-crates --- Cargo.toml | 9 +- websock-mux-proto/Cargo.toml | 19 + websock-mux-proto/src/lib.rs | 7 + websock-mux-proto/src/stream.rs | 154 +++++++ websock-mux-proto/src/varint.rs | 237 +++++++++++ websock-mux/Cargo.toml | 40 ++ websock-mux/examples/echo-client-mux.rs | 126 ++++++ websock-mux/examples/echo-server-mux.rs | 108 +++++ websock-mux/src/lib.rs | 11 + websock-mux/src/tungstenite.rs | 1 + websock-mux/src/wasm.rs | 1 + websock-proto/src/error.rs | 6 + websock-tungstenite-mux/Cargo.toml | 30 ++ websock-tungstenite-mux/src/builder.rs | 104 +++++ websock-tungstenite-mux/src/client.rs | 106 +++++ websock-tungstenite-mux/src/lib.rs | 77 ++++ websock-tungstenite-mux/src/server.rs | 159 +++++++ websock-tungstenite-mux/src/session.rs | 525 ++++++++++++++++++++++++ websock-tungstenite-mux/src/tls/cert.rs | 96 +++++ websock-tungstenite-mux/src/tls/key.rs | 22 + websock-tungstenite-mux/src/tls/mod.rs | 259 ++++++++++++ websock-wasm-mux/Cargo.toml | 7 + websock-wasm-mux/src/lib.rs | 3 + 23 files changed, 2106 insertions(+), 1 deletion(-) create mode 100644 websock-mux-proto/Cargo.toml create mode 100644 websock-mux-proto/src/lib.rs create mode 100644 websock-mux-proto/src/stream.rs create mode 100644 websock-mux-proto/src/varint.rs create mode 100644 websock-mux/Cargo.toml create mode 100644 websock-mux/examples/echo-client-mux.rs create mode 100644 websock-mux/examples/echo-server-mux.rs create mode 100644 websock-mux/src/lib.rs create mode 100644 websock-mux/src/tungstenite.rs create mode 100644 websock-mux/src/wasm.rs create mode 100644 websock-tungstenite-mux/Cargo.toml create mode 100644 websock-tungstenite-mux/src/builder.rs create mode 100644 websock-tungstenite-mux/src/client.rs create mode 100644 websock-tungstenite-mux/src/lib.rs create mode 100644 websock-tungstenite-mux/src/server.rs create mode 100644 websock-tungstenite-mux/src/session.rs create mode 100644 websock-tungstenite-mux/src/tls/cert.rs create mode 100644 websock-tungstenite-mux/src/tls/key.rs create mode 100644 websock-tungstenite-mux/src/tls/mod.rs create mode 100644 websock-wasm-mux/Cargo.toml create mode 100644 websock-wasm-mux/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index d5d9ff6..734ab95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,11 @@ members = [ "websock-proto", "websock-tungstenite", "websock-wasm", - "websock-wasm-demo" + "websock-wasm-demo", + "websock-mux", + "websock-tungstenite-mux", + "websock-wasm-mux", + "websock-mux-proto" ] [workspace.package] @@ -17,4 +21,7 @@ authors = ["shellrow "] websock-proto = { path = "websock-proto", version = "0.1.0" } websock-tungstenite = { path = "websock-tungstenite", version = "0.1.0" } websock-wasm = { path = "websock-wasm", version = "0.1.0" } +websock-mux-proto = { path = "websock-mux-proto", version = "0.1.0" } +websock-tungstenite-mux = { path = "websock-tungstenite-mux", version = "0.1.0" } +websock-wasm-mux = { path = "websock-wasm-mux", version = "0.1.0" } bytes = "1" diff --git a/websock-mux-proto/Cargo.toml b/websock-mux-proto/Cargo.toml new file mode 100644 index 0000000..5f18011 --- /dev/null +++ b/websock-mux-proto/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "websock-mux-proto" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "Protocol for multiplexing WebSocket logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +bytes = { workspace = true } +thiserror = "2" +websock-proto = { workspace = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "1", features = ["rt", "sync", "macros"] } diff --git a/websock-mux-proto/src/lib.rs b/websock-mux-proto/src/lib.rs new file mode 100644 index 0000000..94dbd2e --- /dev/null +++ b/websock-mux-proto/src/lib.rs @@ -0,0 +1,7 @@ +pub mod stream; +pub mod varint; + +pub use stream::{Frame, FrameDecodeError, StreamDir, StreamId}; +pub use varint::{VarInt, VarIntBoundsExceeded, VarIntUnexpectedEnd}; + +pub const SUBPROTOCOL: &str = "websock-mux/1"; diff --git a/websock-mux-proto/src/stream.rs b/websock-mux-proto/src/stream.rs new file mode 100644 index 0000000..bc6a236 --- /dev/null +++ b/websock-mux-proto/src/stream.rs @@ -0,0 +1,154 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::varint::{VarInt, VarIntBoundsExceeded, VarIntUnexpectedEnd}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum StreamDir { + Bi, + Uni, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct StreamId(pub u64); + +impl StreamId { + pub fn new(counter: u64, is_server: bool, dir: StreamDir) -> Result { + let initiator = if is_server { 1 } else { 0 }; + let dir_bit = match dir { + StreamDir::Bi => 0, + StreamDir::Uni => 1, + }; + let value = counter + .checked_shl(2) + .ok_or(VarIntBoundsExceeded)? + | ((dir_bit as u64) << 1) + | (initiator as u64); + VarInt::from_u64(value)?; + Ok(Self(value)) + } + + pub fn dir(self) -> StreamDir { + if (self.0 >> 1) & 1 == 1 { + StreamDir::Uni + } else { + StreamDir::Bi + } + } + + pub fn initiator_is_server(self) -> bool { + self.0 & 1 == 1 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Frame { + OpenUni { id: StreamId }, + OpenBi { id: StreamId }, + Stream { id: StreamId, data: Bytes, fin: bool }, + ResetStream { id: StreamId, code: u64 }, + StopSending { id: StreamId, code: u64 }, + ConnectionClose { code: u64, reason: String }, +} + +impl Frame { + pub fn encode(&self) -> BytesMut { + let mut buf = BytesMut::new(); + match self { + Frame::OpenUni { id } => { + VarInt(0).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + } + Frame::OpenBi { id } => { + VarInt(1).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + } + Frame::Stream { id, data, fin } => { + VarInt(2).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(u64::from(*fin)).encode(&mut buf); + VarInt(data.len() as u64).encode(&mut buf); + buf.put_slice(data); + } + Frame::ResetStream { id, code } => { + VarInt(3).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(*code).encode(&mut buf); + } + Frame::StopSending { id, code } => { + VarInt(4).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(*code).encode(&mut buf); + } + Frame::ConnectionClose { code, reason } => { + VarInt(5).encode(&mut buf); + VarInt(*code).encode(&mut buf); + VarInt(reason.len() as u64).encode(&mut buf); + buf.put_slice(reason.as_bytes()); + } + } + buf + } + + pub fn decode(buf: &mut B) -> Result { + let tag = VarInt::decode(buf)?.into_inner(); + match tag { + 0 => Ok(Frame::OpenUni { + id: StreamId(VarInt::decode(buf)?.into_inner()), + }), + 1 => Ok(Frame::OpenBi { + id: StreamId(VarInt::decode(buf)?.into_inner()), + }), + 2 => { + let id = StreamId(VarInt::decode(buf)?.into_inner()); + let fin = VarInt::decode(buf)?.into_inner() != 0; + let len = VarInt::decode(buf)?.into_inner() as usize; + if buf.remaining() < len { + return Err(FrameDecodeError::UnexpectedEnd); + } + let mut data = vec![0u8; len]; + buf.copy_to_slice(&mut data); + Ok(Frame::Stream { + id, + data: Bytes::from(data), + fin, + }) + } + 3 => Ok(Frame::ResetStream { + id: StreamId(VarInt::decode(buf)?.into_inner()), + code: VarInt::decode(buf)?.into_inner(), + }), + 4 => Ok(Frame::StopSending { + id: StreamId(VarInt::decode(buf)?.into_inner()), + code: VarInt::decode(buf)?.into_inner(), + }), + 5 => { + let code = VarInt::decode(buf)?.into_inner(); + let len = VarInt::decode(buf)?.into_inner() as usize; + if buf.remaining() < len { + return Err(FrameDecodeError::UnexpectedEnd); + } + let mut data = vec![0u8; len]; + buf.copy_to_slice(&mut data); + let reason = String::from_utf8(data).map_err(|_| FrameDecodeError::InvalidUtf8)?; + Ok(Frame::ConnectionClose { code, reason }) + } + _ => Err(FrameDecodeError::UnknownTag(tag)), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum FrameDecodeError { + #[error("unexpected end of buffer")] + UnexpectedEnd, + #[error("unknown frame tag {0}")] + UnknownTag(u64), + #[error("invalid utf-8 in reason")] + InvalidUtf8, +} + +impl From for FrameDecodeError { + fn from(_: VarIntUnexpectedEnd) -> Self { + FrameDecodeError::UnexpectedEnd + } +} diff --git a/websock-mux-proto/src/varint.rs b/websock-mux-proto/src/varint.rs new file mode 100644 index 0000000..9d294c7 --- /dev/null +++ b/websock-mux-proto/src/varint.rs @@ -0,0 +1,237 @@ +//! QUIC variable-length integer encoding and decoding. + +// Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src +// Licensed under Apache-2.0 OR MIT + +use std::{convert::TryInto, fmt, io::Cursor}; + +use bytes::{Buf, BufMut}; +use thiserror::Error; +#[cfg(not(target_arch = "wasm32"))] +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// An integer less than 2^62. +/// +/// Values of this type are suitable for encoding as QUIC variable-length integer. +// Rust does not currently model that the top two bits are reserved for the length tag. +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value. + pub const MAX: Self = Self((1 << 62) - 1); + /// The largest encoded value length. + pub const MAX_SIZE: usize = 8; + + /// Construct a `VarInt` infallibly. + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Succeeds if `x` < 2^62. + pub fn from_u64(x: u64) -> Result { + if x < 2u64.pow(62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a `VarInt` without checking the bounds. + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the integer value. + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Compute the number of bytes needed to encode this value. + pub fn size(self) -> usize { + let x = self.0; + if x < 2u64.pow(6) { + 1 + } else if x < 2u64.pow(14) { + 2 + } else if x < 2u64.pow(30) { + 4 + } else if x < 2u64.pow(62) { + 8 + } else { + unreachable!("malformed VarInt"); + } + } +} + +impl From for u64 { + fn from(x: VarInt) -> Self { + x.0 + } +} + +impl From for VarInt { + fn from(x: u8) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u16) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u32) -> Self { + Self(x.into()) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: u64) -> Result { + Self::from_u64(x) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: u128) -> Result { + Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: usize) -> Result { + Self::try_from(x as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl VarInt { + pub fn decode(r: &mut B) -> Result { + if !r.has_remaining() { + return Err(VarIntUnexpectedEnd); + } + let mut buf = [0; 8]; + buf[0] = r.get_u8(); + let tag = buf[0] >> 6; + buf[0] &= 0b0011_1111; + let x = match tag { + 0b00 => u64::from(buf[0]), + 0b01 => { + if r.remaining() < 1 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..2]); + u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) + } + 0b10 => { + if r.remaining() < 3 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..4]); + u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) + } + 0b11 => { + if r.remaining() < 7 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..8]); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + Ok(Self(x)) + } + + // Read a varint from an async stream. + #[cfg(not(target_arch = "wasm32"))] + pub async fn read(stream: &mut S) -> Result { + // Eight bytes is the maximum encoded length. + let mut buf = [0; 8]; + + // Read the first byte because it encodes the length tag. + stream + .read_exact(&mut buf[0..1]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // 0b00 = 1 byte, 0b01 = 2 bytes, 0b10 = 4 bytes, 0b11 = 8 bytes. + let size = 1 << (buf[0] >> 6); + stream + .read_exact(&mut buf[1..size]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // Use a cursor to decode from the stack buffer. + let mut cursor = Cursor::new(&buf[..size]); + let v = VarInt::decode(&mut cursor).unwrap(); + + Ok(v) + } + + pub fn encode(&self, w: &mut B) { + let x = self.0; + if x < 2u64.pow(6) { + w.put_u8(x as u8); + } else if x < 2u64.pow(14) { + w.put_u16((0b01 << 14) | x as u16); + } else if x < 2u64.pow(30) { + w.put_u32((0b10 << 30) | x as u32); + } else if x < 2u64.pow(62) { + w.put_u64((0b11 << 62) | x); + } else { + unreachable!("malformed VarInt") + } + } + + #[cfg(not(target_arch = "wasm32"))] + pub async fn write( + &self, + stream: &mut S, + ) -> Result<(), VarIntUnexpectedEnd> { + // Keep the temporary buffer on the stack to avoid allocation. + let mut buf = [0u8; 8]; + let mut cursor: &mut [u8] = &mut buf; + self.encode(&mut cursor); + let size = 8 - cursor.len(); + + let mut cursor = &buf[..size]; + stream + .write_all_buf(&mut cursor) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + Ok(()) + } +} + +/// Error returned when constructing a `VarInt` from a value >= 2^62. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +#[error("value too large for varint encoding")] +pub struct VarIntBoundsExceeded; + +#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)] +#[error("unexpected end of buffer")] +pub struct VarIntUnexpectedEnd; diff --git a/websock-mux/Cargo.toml b/websock-mux/Cargo.toml new file mode 100644 index 0000000..898400f --- /dev/null +++ b/websock-mux/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "websock-mux" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "WebSocket multiplexing layer for logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex", "native", "wasm"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +websock-proto = { workspace = true } +websock-mux-proto = { workspace = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +websock-tungstenite-mux = { workspace = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +websock-wasm-mux = { workspace = true } + +[dev-dependencies] +tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros"] } +bytes = { workspace = true } +anyhow = "1" +tracing = "0.1" +tracing-subscriber = "0.3" +clap = { version = "4", features = ["derive"] } +url = "2" +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-pemfile = "2.2" + +[[example]] +name = "echo-client-mux" +path = "examples/echo-client-mux.rs" + +[[example]] +name = "echo-server-mux" +path = "examples/echo-server-mux.rs" diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs new file mode 100644 index 0000000..4f57cba --- /dev/null +++ b/websock-mux/examples/echo-client-mux.rs @@ -0,0 +1,126 @@ +//! Echo client example for the websock mux transport. +//! +//! This demonstrates opening a bidirectional stream and echoing bytes. + +use clap::Parser; +use rustls::{client::ClientConfig, RootCertStore}; +use std::{path, sync::Arc}; +use tracing::Level; +use tracing_subscriber::FmtSubscriber; +use url::Url; +use websock_mux::{ + tls::{self, TlsClientConfigBuilder}, + Client, ClientBuilder, +}; + +const DEFAULT_ECHO_URL: &str = "ws://127.0.0.1:9001"; + +/// Command-line arguments for the echo mux client. +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// WebSocket server URL (ws:// or wss://). + #[arg(short, long)] + url: Option, + + /// Accept the server certificates at this path, encoded as PEM. + #[arg(long)] + tls_cert: Option, + + /// Dangerous: Disable TLS certificate verification. + #[arg(long, default_value = "false")] + tls_disable_verify: bool, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing subscriber for logging in this example. + let subscriber = FmtSubscriber::builder() + .with_max_level(Level::DEBUG) + .finish(); + + tracing::subscriber::set_global_default(subscriber) + .expect("failed to set global tracing subscriber"); + + let args = Args::parse(); + let url = args + .url + .clone() + .unwrap_or_else(|| Url::parse(DEFAULT_ECHO_URL).expect("default url")); + + let client: Client = build_client(&url, &args)?; + tracing::info!("connecting to {}", url); + + let session = client.connect(url.as_str()).await?; + tracing::info!("connected"); + + tracing::info!("opening bidirectional stream"); + let (send, mut recv) = session.open_bi().await?; + tracing::info!("opened bidirectional stream"); + + send.write_all(b"hello mux").await?; + send.finish().await?; + + while let Some(chunk) = recv.read_chunk(1024).await? { + tracing::info!("Received: {}", String::from_utf8_lossy(&chunk)); + } + Ok(()) +} + +/// Build a mux client based on the URL scheme and CLI options. +fn build_client(url: &Url, args: &Args) -> anyhow::Result { + // Plain WS requires no TLS configuration. + if url.scheme() == "ws" { + return Ok(ClientBuilder::new().build()); + } + + // From here: wss:// with TLS enabled. + let is_local = is_localhost_url(url); + + let tls_cfg: ClientConfig = if args.tls_disable_verify { + tracing::warn!("disabling TLS certificate verification"); + TlsClientConfigBuilder::new_insecure()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build() + } else if let Some(path) = &args.tls_cert { + let certs = tls::cert::load_certs(path)?; + anyhow::ensure!(!certs.is_empty(), "could not find certificate"); + + let mut roots = RootCertStore::empty(); + for c in certs { + let _ = roots.add(c); + } + + let mut cfg = ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + cfg.alpn_protocols = vec![b"h3".to_vec()]; + cfg + } else if is_local { + tracing::warn!( + "no --tls-cert provided and target looks local ({}); \ + using insecure mode for quick testing (equivalent to --tls-disable-verify)", + url + ); + TlsClientConfigBuilder::new_insecure()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build() + } else { + TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build() + }; + + Ok(ClientBuilder::new() + .with_tls_config(Arc::new(tls_cfg)) + .build()) +} + +/// Determine whether a URL points to a loopback host. +fn is_localhost_url(url: &Url) -> bool { + match url.host_str() { + Some("localhost") => true, + Some(host) => host == "127.0.0.1" || host == "::1", + None => false, + } +} diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs new file mode 100644 index 0000000..456323e --- /dev/null +++ b/websock-mux/examples/echo-server-mux.rs @@ -0,0 +1,108 @@ +//! Echo server example for the websock mux transport. +//! +//! This server accepts mux sessions and echoes bytes over each bidirectional stream. + +use clap::Parser; +use std::{io, fs, path}; +use tracing::Level; +use tracing_subscriber::FmtSubscriber; +use websock_mux::{Server, ServerBuilder}; + +/// Command-line arguments for the echo mux server. +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Bind address for the server. + #[arg(short, long, default_value = "127.0.0.1:9001")] + addr: std::net::SocketAddr, + + /// Use the certificates at this path, encoded as PEM (enables wss://). + #[arg(long)] + tls_cert: Option, + + /// Use the private key at this path, encoded as PEM (enables wss://). + #[arg(long)] + tls_key: Option, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing subscriber for logging in this example. + let subscriber = FmtSubscriber::builder() + .with_max_level(Level::DEBUG) + .finish(); + tracing::subscriber::set_global_default(subscriber) + .expect("failed to set global tracing subscriber"); + + let args = Args::parse(); + + let mut builder = ServerBuilder::new(); + + match (&args.tls_cert, &args.tls_key) { + (Some(cert_path), Some(key_path)) => { + let (chain, key) = load_pem_cert_and_key(cert_path, key_path)?; + let mut cfg = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(chain, key) + .map_err(|e| anyhow::anyhow!(e.to_string()))?; + cfg.alpn_protocols = vec![b"h3".to_vec()]; + builder = builder.with_tls_config(cfg); + tracing::info!("TLS enabled (wss://)"); + } + (None, None) => { + tracing::warn!( + "TLS disabled (ws://). Provide --tls-cert/--tls-key to enable wss://" + ); + } + _ => anyhow::bail!("both --tls-cert and --tls-key must be provided together, or neither"), + } + + let server: Server = builder.build(args.addr).await?; + let scheme = if args.tls_cert.is_some() { "wss" } else { "ws" }; + tracing::info!("listening on {} ({}://{})", args.addr, scheme, args.addr); + + loop { + let session = server.accept().await?; + tracing::info!("accepted connection"); + tokio::spawn(async move { + loop { + let (send, mut recv) = match session.accept_bi().await { + Ok(streams) => streams, + Err(_) => break, + }; + tracing::info!("accepted bidirectional stream"); + tokio::spawn(async move { + let send = send; + while let Ok(Some(chunk)) = recv.read_chunk(1024).await { + let chunk_len = chunk.len(); + tracing::info!("Received chunk of size: {}", chunk_len); + if send.write_buf(chunk).await.is_err() { + break; + } + tracing::info!("Sent chunk of size: {}", chunk_len); + } + let _ = send.finish().await; + }); + } + }); + } +} + +/// Load a PEM-encoded certificate chain and private key from disk. +fn load_pem_cert_and_key( + cert_path: &path::Path, + key_path: &path::Path, +) -> anyhow::Result<(Vec>, rustls::pki_types::PrivateKeyDer<'static>)> { + let chain_file = fs::File::open(cert_path)?; + let mut chain_reader = io::BufReader::new(chain_file); + let chain: Vec> = + rustls_pemfile::certs(&mut chain_reader).collect::>()?; + anyhow::ensure!(!chain.is_empty(), "could not find certificate"); + + let key_file = fs::File::open(key_path)?; + let mut key_reader = io::BufReader::new(key_file); + let key = rustls_pemfile::private_key(&mut key_reader)? + .ok_or_else(|| anyhow::anyhow!("missing private key"))?; + + Ok((chain, key)) +} diff --git a/websock-mux/src/lib.rs b/websock-mux/src/lib.rs new file mode 100644 index 0000000..8de2e42 --- /dev/null +++ b/websock-mux/src/lib.rs @@ -0,0 +1,11 @@ +pub use websock_proto::*; + +#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] +#[path = "tungstenite.rs"] +mod websocket; + +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +#[path = "wasm.rs"] +mod websocket; + +pub use websocket::*; diff --git a/websock-mux/src/tungstenite.rs b/websock-mux/src/tungstenite.rs new file mode 100644 index 0000000..c47354d --- /dev/null +++ b/websock-mux/src/tungstenite.rs @@ -0,0 +1 @@ +pub use websock_tungstenite_mux::*; diff --git a/websock-mux/src/wasm.rs b/websock-mux/src/wasm.rs new file mode 100644 index 0000000..4613d1e --- /dev/null +++ b/websock-mux/src/wasm.rs @@ -0,0 +1 @@ +pub use websock_wasm_mux::*; diff --git a/websock-proto/src/error.rs b/websock-proto/src/error.rs index 78c4a7f..d54b9e6 100644 --- a/websock-proto/src/error.rs +++ b/websock-proto/src/error.rs @@ -30,6 +30,12 @@ pub enum Error { #[error("unsupported: {0}")] Unsupported(String), + #[error("frame decode error: {0}")] + FrameDecode(String), + + #[error("stream id error: {0}")] + StreamId(String), + /// A catch-all error for unexpected failures. #[error("other error: {0}")] Other(String), diff --git a/websock-tungstenite-mux/Cargo.toml b/websock-tungstenite-mux/Cargo.toml new file mode 100644 index 0000000..6be528c --- /dev/null +++ b/websock-tungstenite-mux/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "websock-tungstenite-mux" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "Native WebSocket multiplexing layer for logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex", "native", "wasm"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +bytes = { workspace = true } +thiserror = "2" +websock-proto = { workspace = true } +websock-mux-proto = { workspace = true } +tokio = { version = "1", features = ["rt", "sync", "macros"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"]} +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-native-certs = "0.8" +rcgen = "0.14" +rustls-pemfile = "2.2" +tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] } +futures-util = { version = "0.3" } +futures-core = { version = "0.3" } +futures-sink = { version = "0.3" } +tokio-util = { version = "0.7" } +#http = "1" +tracing = "0.1" diff --git a/websock-tungstenite-mux/src/builder.rs b/websock-tungstenite-mux/src/builder.rs new file mode 100644 index 0000000..87eeaa6 --- /dev/null +++ b/websock-tungstenite-mux/src/builder.rs @@ -0,0 +1,104 @@ +use websock_proto::Result; +use std::sync::Arc; +use rustls::{ClientConfig, ServerConfig}; + +use crate::{bind, Client, Server}; +use websock_mux_proto::SUBPROTOCOL; + +/// Builder for a mux WebSocket client. +#[derive(Debug, Clone)] +pub struct ClientBuilder { + pub(crate) opts: websock_proto::ConnectOptions, + pub(crate) tls: Option>, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + pub fn new() -> Self { + let mut opts = websock_proto::ConnectOptions::default(); + opts.protocols.insert(0, SUBPROTOCOL.to_string()); + Self { opts, tls: None } + } + + pub fn with_options(mut self, opts: websock_proto::ConnectOptions) -> Self { + self.opts = opts; + if !self.opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + self.opts.protocols.insert(0, SUBPROTOCOL.to_string()); + } + self + } + + /// Configure a custom rustls client config (for `wss://`). + pub fn with_tls_config(mut self, tls: Arc) -> Self { + self.tls = Some(tls); + self + } + + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.opts.headers.push((name.into(), value.into())); + self + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.opts.protocols.push(protocol.into()); + self + } + + pub fn build(self) -> Client { + Client { + opts: self.opts, + tls: self.tls, + } + } +} + +/// Builder for a mux WebSocket server. +#[derive(Debug, Clone)] +pub struct ServerBuilder { + pub(crate) opts: websock_proto::ServerOptions, + pub(crate) tls: Option, +} + +impl ServerBuilder { + pub fn new() -> Self { + let mut opts = websock_proto::ServerOptions::default(); + opts.protocols.push(SUBPROTOCOL.to_string()); + Self { opts, tls: None } + } + + pub fn with_options(mut self, opts: websock_proto::ServerOptions) -> Self { + self.opts = opts; + if !self.opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + self.opts.protocols.insert(0, SUBPROTOCOL.to_string()); + } + self + } + + /// Configure TLS for incoming connections (accept `wss://`). + pub fn with_tls_config(mut self, tls: ServerConfig) -> Self { + self.tls = Some(tls); + self + } + + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.opts.headers.push((name.into(), value.into())); + self + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.opts.protocols.push(protocol.into()); + self + } + + pub async fn build(self, addr: A) -> Result + where + A: tokio::net::ToSocketAddrs, + { + bind(addr, self.opts, self.tls).await + } +} diff --git a/websock-tungstenite-mux/src/client.rs b/websock-tungstenite-mux/src/client.rs new file mode 100644 index 0000000..62c8e60 --- /dev/null +++ b/websock-tungstenite-mux/src/client.rs @@ -0,0 +1,106 @@ +use tokio_tungstenite::tungstenite; +use tungstenite::client::IntoClientRequest; +use tungstenite::http; +use tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL; + +use std::sync::Arc; + +use rustls::ClientConfig; +use tokio_tungstenite::Connector; +use websock_proto::{Error, Result}; + +use crate::session::map_tungstenite_err; +use crate::Session; +use websock_mux_proto::SUBPROTOCOL; + +fn negotiated_protocol(resp: &http::Response>>) -> Option<&str> { + resp.headers() + .get(SEC_WEBSOCKET_PROTOCOL)? + .to_str() + .ok() + .map(|s| s.trim()) +} + +fn is_protocol_ok(p: &str) -> bool { + p.eq_ignore_ascii_case(SUBPROTOCOL) +} + +fn validate_client_protocols(opts: &websock_proto::ConnectOptions) -> Result<()> { + for p in &opts.protocols { + tungstenite::http::HeaderValue::from_str(p) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + } + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) opts: websock_proto::ConnectOptions, + pub(crate) tls: Option>, +} + +impl Client { + pub fn options(&self) -> &websock_proto::ConnectOptions { + &self.opts + } + + /// Return the configured TLS client config (if any). + pub fn tls_config(&self) -> Option<&Arc> { + self.tls.as_ref() + } + + pub async fn connect(&self, url: &str) -> Result { + self.connect_with_tls(url, self.tls.clone()).await + } + + /// Connect with an explicit TLS configuration. + /// + /// When `tls` is `None`, the default Tokio Tungstenite TLS settings are used + /// (and plain `ws://` works as-is). + pub async fn connect_with_tls( + &self, + url: &str, + tls: Option>, + ) -> Result { + validate_client_protocols(&self.opts)?; + + let mut request = url + .into_client_request() + .map_err(|e| websock_proto::Error::InvalidUrl(e.to_string()))?; + + let headers = request.headers_mut(); + for (k, v) in self.opts.headers.iter() { + let name = tungstenite::http::header::HeaderName::from_bytes(k.as_bytes()) + .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?; + let value = tungstenite::http::header::HeaderValue::from_str(&v) + .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?; + headers.append(name, value); + } + + // Apply subprotocols. + if !self.opts.protocols.is_empty() { + let joined = self.opts.protocols.join(","); + let value = tungstenite::http::header::HeaderValue::from_str(&joined) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + headers.insert(SEC_WEBSOCKET_PROTOCOL, value); + } + + let connector = tls.map(Connector::Rustls); + let (stream, response) = tokio_tungstenite::connect_async_tls_with_config( + request, + None, + false, + connector, + ) + .await + .map_err(map_tungstenite_err)?; + + let proto = negotiated_protocol(&response) + .ok_or_else(|| Error::Protocol("missing SEC_WEBSOCKET_PROTOCOL in response".into()))?; + if !is_protocol_ok(proto) { + return Err(Error::Protocol(format!("subprotocol mismatch: {proto}"))); + } + + Session::new(stream, false) + } +} diff --git a/websock-tungstenite-mux/src/lib.rs b/websock-tungstenite-mux/src/lib.rs new file mode 100644 index 0000000..70048df --- /dev/null +++ b/websock-tungstenite-mux/src/lib.rs @@ -0,0 +1,77 @@ +//! Tokio + tokio-tungstenite based WebSocket multiplexing transport. +//! +//! This crate provides a QUIC/WebTransport-like logical stream interface over a single WebSocket. +//! The wire format (frames, StreamId, VarInt, etc.) lives in `websock-mux-proto`. + +mod builder; +mod client; +mod server; +mod session; +pub mod tls; + +pub use builder::{ClientBuilder, ServerBuilder}; +pub use client::Client; +pub use server::{bind, Server}; +pub use session::{RecvStream, SendStream, Session}; +pub use tls::{ + TlsClientConfig, TlsClientConfigBuilder, TlsConfig, TlsServerConfig, TlsServerConfigBuilder, +}; + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use tokio::sync::mpsc; + use websock_mux_proto::{Frame, StreamDir, StreamId}; + + use crate::session::SessionInner; + + #[test] + fn frame_roundtrip() { + let id = StreamId(4); + let frame = Frame::Stream { + id, + data: Bytes::from_static(b"hello"), + fin: true, + }; + let mut buf = frame.encode(); + let decoded = Frame::decode(&mut buf).expect("decode"); + assert_eq!(frame, decoded); + } + + #[tokio::test] + async fn stream_open_data_fin() { + let (outbound_tx, _outbound_rx) = mpsc::channel(4); + let (accept_uni_tx, mut accept_uni_rx) = mpsc::channel(4); + let (accept_bi_tx, _accept_bi_rx) = mpsc::channel(4); + let inner = std::sync::Arc::new(SessionInner::new( + false, + outbound_tx, + accept_uni_tx, + accept_bi_tx, + )); + let id = StreamId::new(0, true, StreamDir::Uni).expect("stream id"); + inner + .clone() + .handle_frame(Frame::OpenUni { id }) + .await + .expect("open"); + let mut recv = accept_uni_rx.recv().await.expect("recv stream"); + let data = Bytes::from_static(b"ping"); + inner + .clone() + .handle_frame(Frame::Stream { + id, + data: data.clone(), + fin: true, + }) + .await + .expect("stream"); + let mut buf = BytesMut::new(); + let n = recv.read_buf::(&mut buf).await.expect("read"); + assert_eq!(n, Some(4)); + assert_eq!(buf.as_ref(), data.as_ref()); + let max_size: usize = 1024; + let end = recv.read_chunk(max_size).await.expect("fin"); + assert!(end.is_none()); + } +} diff --git a/websock-tungstenite-mux/src/server.rs b/websock-tungstenite-mux/src/server.rs new file mode 100644 index 0000000..526073d --- /dev/null +++ b/websock-tungstenite-mux/src/server.rs @@ -0,0 +1,159 @@ +use std::sync::Arc; +use tokio_tungstenite::tungstenite; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio_rustls::TlsAcceptor; +use tokio_tungstenite::accept_hdr_async; +use tokio_tungstenite::tungstenite::handshake::server; +use tungstenite::http; +use tungstenite::http::header::{HeaderName, HeaderValue, SEC_WEBSOCKET_PROTOCOL}; + +use websock_proto::{Error, Result}; + +use crate::session::map_tungstenite_err; +use crate::Session; +use websock_mux_proto::SUBPROTOCOL; + +/// Marker trait for IO types usable by the server. +pub trait ServerIo: AsyncRead + AsyncWrite + Unpin + Send {} + +impl ServerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +/// Boxed stream type used for server connections. +pub type ServerStream = Box; + +/// Convert configured headers into tungstenite types. +fn prepare_headers(opts: &websock_proto::ServerOptions) -> Result> { + let mut out = Vec::new(); + for (k, v) in &opts.headers { + let name = HeaderName::from_bytes(k.as_bytes()) + .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?; + let value = HeaderValue::from_str(v) + .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?; + out.push((name, value)); + } + Ok(out) +} + +/// Validate configured subprotocols before binding. +fn validate_protocols(opts: &websock_proto::ServerOptions) -> Result<()> { + for protocol in &opts.protocols { + HeaderValue::from_str(protocol) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + } + Ok(()) +} + +/// Select the first requested subprotocol that appears in the allowed list. +fn select_protocol(req: &server::Request, allowed: &[String]) -> Option { + if allowed.is_empty() { + return None; + } + let header = req.headers().get(SEC_WEBSOCKET_PROTOCOL)?; + let header = header.to_str().ok()?; + for candidate in header.split(',').map(|s| s.trim()) { + if allowed.iter().any(|p| p == candidate) { + return Some(candidate.to_string()); + } + } + None +} + +/// Bind a WebSocket server listener. +pub async fn bind( + addr: A, + opts: websock_proto::ServerOptions, + tls: Option, +) -> Result +where + A: ToSocketAddrs, +{ + let listener = TcpListener::bind(addr) + .await + .map_err(|e| Error::Io(e.to_string()))?; + let headers = prepare_headers(&opts)?; + validate_protocols(&opts)?; + + if !opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + return Err(Error::Protocol("SUBPROTOCOL missing in ServerOptions".into())); + } + + let acceptor = tls.map(|cfg| TlsAcceptor::from(Arc::new(cfg))); + + Ok(Server { + listener, + opts, + headers, + acceptor, + }) +} + +pub struct Server { + listener: TcpListener, + opts: websock_proto::ServerOptions, + headers: Vec<(HeaderName, HeaderValue)>, + acceptor: Option, +} + +impl Server { + pub async fn accept(&self) -> Result { + let (stream, _) = self + .listener + .accept() + .await + .map_err(|e| Error::Io(e.to_string()))?; + + let (stream, _is_tls): (ServerStream, bool) = if let Some(acceptor) = &self.acceptor { + let tls_stream = acceptor + .accept(stream) + .await + .map_err(|e| Error::Tls(e.to_string()))?; + (Box::new(tls_stream), true) + } else { + (Box::new(stream), false) + }; + + let headers = self.headers.clone(); + let allowed = self.opts.protocols.clone(); + + let ws = accept_hdr_async(stream, move |req: &server::Request, mut resp: server::Response| { + // Additional headers from configuration + for (name, value) in &headers { + resp.headers_mut().append(name, value.clone()); + } + + // websock-mux is required protocol + let Some(protocol) = select_protocol(req, &allowed) else { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + }; + + // Confirm that the required protocol is present + if !protocol.eq_ignore_ascii_case(SUBPROTOCOL) { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + } + + resp.headers_mut().insert( + http::header::SEC_WEBSOCKET_PROTOCOL, + http::HeaderValue::from_str(&protocol).expect("validated"), + ); + + Ok(resp) + }) + .await + .map_err(map_tungstenite_err)?; + + Session::new(ws, true) + } + + pub fn local_addr(&self) -> Result { + self.listener + .local_addr() + .map_err(|e| Error::Io(e.to_string())) + } +} diff --git a/websock-tungstenite-mux/src/session.rs b/websock-tungstenite-mux/src/session.rs new file mode 100644 index 0000000..1fe9520 --- /dev/null +++ b/websock-tungstenite-mux/src/session.rs @@ -0,0 +1,525 @@ +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, +}; + +use bytes::{BufMut, Bytes}; +use futures_util::{SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::{mpsc, Mutex}; +use tokio_tungstenite::tungstenite; +use websock_proto::{Error, Result}; + +use websock_mux_proto::stream::{Frame, StreamDir, StreamId}; + +#[derive(Clone)] +pub struct Session { + inner: Arc, + accept_uni: Arc>>, + accept_bi: Arc>>, +} + +impl Session { + pub(crate) fn new( + stream: tokio_tungstenite::WebSocketStream, + is_server: bool, + ) -> Result + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (outbound_tx, outbound_rx) = mpsc::channel(128); + let (accept_uni_tx, accept_uni_rx) = mpsc::channel(128); + let (accept_bi_tx, accept_bi_rx) = mpsc::channel(128); + + let inner = Arc::new(SessionInner::new( + is_server, + outbound_tx, + accept_uni_tx, + accept_bi_tx, + )); + + let session = Self { + inner: inner.clone(), + accept_uni: Arc::new(Mutex::new(accept_uni_rx)), + accept_bi: Arc::new(Mutex::new(accept_bi_rx)), + }; + + inner.spawn_tasks(stream, outbound_rx); + Ok(session) + } + + pub async fn open_uni(&self) -> Result { + let id = self.inner.next_stream_id(StreamDir::Uni)?; + self.inner.send_frame(Frame::OpenUni { id }).await?; + Ok(SendStream::new(id, self.inner.clone())) + } + + pub async fn open_bi(&self) -> Result<(SendStream, RecvStream)> { + let id = self.inner.next_stream_id(StreamDir::Bi)?; + let recv = self.inner.register_recv_stream(id).await; + self.inner.send_frame(Frame::OpenBi { id }).await?; + Ok((SendStream::new(id, self.inner.clone()), recv)) + } + + pub async fn accept_uni(&self) -> Result { + let mut rx = self.accept_uni.lock().await; + rx.recv().await.ok_or(Error::Closed) + } + + pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream)> { + let mut rx = self.accept_bi.lock().await; + rx.recv().await.ok_or(Error::Closed) + } +} + +#[derive(Clone)] +pub struct SendStream { + id: StreamId, + session: Arc, + finished: Arc, +} + +impl SendStream { + fn new(id: StreamId, session: Arc) -> Self { + Self { + id, + session, + finished: Arc::new(AtomicBool::new(false)), + } + } + + pub async fn write(&self, data: &[u8]) -> Result<()> { + self.write_buf(Bytes::copy_from_slice(data)).await + } + + pub async fn write_buf(&self, data: Bytes) -> Result<()> { + self.session + .send_frame(Frame::Stream { + id: self.id, + data, + fin: false, + }) + .await + } + + pub async fn write_all(&self, data: &[u8]) -> Result<()> { + self.write(data).await + } + + pub async fn finish(&self) -> Result<()> { + if self + .finished + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + self.session + .send_frame(Frame::Stream { + id: self.id, + data: Bytes::new(), + fin: true, + }) + .await?; + } + Ok(()) + } + + pub async fn reset(&self, code: u64) -> Result<()> { + self.finished.store(true, Ordering::SeqCst); + self.session + .send_frame(Frame::ResetStream { id: self.id, code }) + .await + } + + pub fn closed(&self) -> bool { + self.finished.load(Ordering::SeqCst) + } +} + +impl Drop for SendStream { + fn drop(&mut self) { + if !self.finished.load(Ordering::SeqCst) { + self.session + .try_send_frame(Frame::ResetStream { id: self.id, code: 0 }); + } + } +} + +#[derive(Debug)] +struct RecvEvent { + data: Bytes, + fin: bool, +} + +pub struct RecvStream { + id: StreamId, + session: Arc, + receiver: mpsc::Receiver, + finished: bool, + pending: Bytes, +} + +impl RecvStream { + fn new(id: StreamId, session: Arc, receiver: mpsc::Receiver) -> Self { + Self { + id, + session, + receiver, + finished: false, + pending: Bytes::new(), + } + } + + pub async fn read(&mut self, buf: &mut [u8]) -> Result> { + if self.finished { + return Ok(None); + } + if self.pending.is_empty() { + if let Some(chunk) = self.read_chunk_internal().await? { + self.pending = chunk; + } else { + return Ok(None); + } + } + let amt = buf.len().min(self.pending.len()); + buf[..amt].copy_from_slice(&self.pending[..amt]); + self.pending = self.pending.slice(amt..); + Ok(Some(amt)) + } + + pub async fn read_buf(&mut self, buf: &mut B) -> Result> { + let mut temp = vec![0u8; 4096]; + let size_opt = self.read(&mut temp).await?; + if let Some(size) = size_opt { + buf.put_slice(&temp[..size]); + Ok(Some(size)) + } else { + Ok(None) + } + } + + pub async fn read_chunk_internal(&mut self) -> Result> { + match self.receiver.recv().await { + Some(event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub async fn read_chunk(&mut self, max: usize) -> Result> { + match self.receiver.recv().await { + Some(mut event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else if event.data.len() > max { + let chunk = event.data.split_to(max); + self.pending = event.data; + Ok(Some(chunk)) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub async fn stop(&self, code: u64) -> Result<()> { + self.session + .send_frame(Frame::StopSending { id: self.id, code }) + .await + } + + pub fn closed(&self) -> bool { + self.finished + } +} + +impl Drop for RecvStream { + fn drop(&mut self) { + if !self.finished { + self.session + .try_send_frame(Frame::StopSending { id: self.id, code: 0 }); + } + } +} + +pub(crate) struct SessionInner { + is_server: bool, + outbound_tx: mpsc::Sender, + accept_uni_tx: Mutex>>, + accept_bi_tx: Mutex>>, + streams: Mutex>>, + next_uni: AtomicU64, + next_bi: AtomicU64, + closed: AtomicBool, +} + +impl SessionInner { + pub(crate) fn new( + is_server: bool, + outbound_tx: mpsc::Sender, + accept_uni_tx: mpsc::Sender, + accept_bi_tx: mpsc::Sender<(SendStream, RecvStream)>, + ) -> Self { + Self { + is_server, + outbound_tx, + accept_uni_tx: Mutex::new(Some(accept_uni_tx)), + accept_bi_tx: Mutex::new(Some(accept_bi_tx)), + streams: Mutex::new(HashMap::new()), + next_uni: AtomicU64::new(0), + next_bi: AtomicU64::new(0), + closed: AtomicBool::new(false), + } + } + + fn spawn_tasks( + self: Arc, + stream: tokio_tungstenite::WebSocketStream, + mut outbound_rx: mpsc::Receiver, + ) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (mut ws_sink, mut ws_stream) = stream.split(); + + let inbound = self.clone(); + tokio::spawn(async move { + while let Some(msg) = ws_stream.next().await { + let msg = match msg { + Ok(m) => m, + Err(_) => break, + }; + + match msg { + tungstenite::Message::Binary(data) => { + let mut cursor = &data[..]; + let frame = match Frame::decode(&mut cursor) { + Ok(f) => f, + Err(_) => break, + }; + if inbound.clone().handle_frame(frame).await.is_err() { + break; + } + } + tungstenite::Message::Ping(p) => { + let _ = inbound + .outbound_tx + .try_send(tungstenite::Message::Pong(p)); + } + tungstenite::Message::Close(_) => break, + _ => {} + } + } + + inbound.close_all().await; + }); + + let outbound = self.clone(); + tokio::spawn(async move { + while let Some(msg) = outbound_rx.recv().await { + if ws_sink.send(msg).await.is_err() { + break; + } + } + outbound.close_all().await; + }); + } + + pub(crate) async fn handle_frame(self: Arc, frame: Frame) -> Result<()> { + match frame { + Frame::OpenUni { id } => { + if id.dir() != StreamDir::Uni { + return self + .protocol_error(1, "OpenUni with non-uni StreamId") + .await; + } + if id.initiator_is_server() != self.peer_is_server() { + return self + .protocol_error(1, "OpenUni with wrong initiator") + .await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + let recv = self.register_recv_stream(id).await; + + let tx = { self.accept_uni_tx.lock().await.clone() }; + if let Some(tx) = tx { + tx.send(recv).await.map_err(|_| Error::Closed)?; + } else { + return Err(Error::Closed); + } + } + Frame::OpenBi { id } => { + if id.dir() != StreamDir::Bi { + return self + .protocol_error(1, "OpenBi with non-bi StreamId") + .await; + } + if id.initiator_is_server() != self.peer_is_server() { + return self + .protocol_error(1, "OpenBi with wrong initiator") + .await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + let recv = self.register_recv_stream(id).await; + let send = SendStream::new(id, self.clone()); + + let tx = { self.accept_bi_tx.lock().await.clone() }; + if let Some(tx) = tx { + tx.send((send, recv)).await.map_err(|_| Error::Closed)?; + } else { + return Err(Error::Closed); + } + } + Frame::Stream { id, data, fin } => { + let mut streams = self.streams.lock().await; + + let Some(tx) = streams.get(&id) else { + return self + .protocol_error(1, "Stream data on unknown stream") + .await; + }; + + let _ = tx.send(RecvEvent { data, fin }).await; + + if fin { + streams.remove(&id); + } + } + Frame::ResetStream { id, .. } | Frame::StopSending { id, .. } => { + let mut streams = self.streams.lock().await; + if streams.remove(&id).is_none() { + return self + .protocol_error(1, "reset/stop on unknown stream") + .await; + } + } + Frame::ConnectionClose { .. } => { + let mut streams = self.streams.lock().await; + streams.clear(); + } + } + Ok(()) + } + + pub(crate) async fn register_recv_stream(self: &Arc, id: StreamId) -> RecvStream { + let (tx, rx) = mpsc::channel(128); + let mut streams = self.streams.lock().await; + streams.insert(id, tx); + RecvStream::new(id, self.clone(), rx) + } + + pub(crate) fn next_stream_id(&self, dir: StreamDir) -> Result { + let counter = match dir { + StreamDir::Uni => self.next_uni.fetch_add(1, Ordering::SeqCst), + StreamDir::Bi => self.next_bi.fetch_add(1, Ordering::SeqCst), + }; + StreamId::new(counter, self.is_server, dir).map_err(|e| Error::StreamId(e.to_string())) + } + + pub(crate) async fn send_frame(&self, frame: Frame) -> Result<()> { + if self.is_closed() { + return Err(Error::Closed); + } + let data = frame.encode().freeze(); + self.outbound_tx + .send(tungstenite::Message::Binary(data)) + .await + .map_err(|_| Error::Closed) + } + + pub(crate) fn try_send_frame(&self, frame: Frame) { + if self.is_closed() { + return; + } + let data = frame.encode().freeze(); + let _ = self.outbound_tx.try_send(tungstenite::Message::Binary(data)); + } + + pub(crate) async fn close_all(&self) { + if self.closed.swap(true, Ordering::SeqCst) { + return; + } + // Close accept channels + { + let mut tx = self.accept_uni_tx.lock().await; + tx.take(); + } + { + let mut tx = self.accept_bi_tx.lock().await; + tx.take(); + } + // Close existing streams + { + let mut streams = self.streams.lock().await; + streams.clear(); + } + } + + pub(crate) fn is_closed(&self) -> bool { + self.closed.load(Ordering::SeqCst) + } + + async fn protocol_error(self: &Arc, code: u64, reason: impl Into) -> Result<()> { + let reason = reason.into(); + + // Notify the peer (it's okay if sending fails...) + let _ = self.try_send_frame(Frame::ConnectionClose { + code, + reason: reason.clone(), + }); + + self.close_all().await; + Err(Error::Protocol(reason)) + } + + fn peer_is_server(&self) -> bool { + !self.is_server + } +} + +/// Map tungstenite errors into the shared error type. +pub(crate) fn map_tungstenite_err(e: tungstenite::Error) -> Error { + use tungstenite::Error as E; + match e { + E::ConnectionClosed | E::AlreadyClosed => Error::Closed, + E::Io(io) => Error::Io(io.to_string()), + E::Tls(tls) => Error::Tls(tls.to_string()), + E::Url(url) => Error::InvalidUrl(url.to_string()), + E::Protocol(err) => Error::Protocol(err.to_string()), + E::Utf8(err) => Error::Protocol(err), + E::Capacity(err) => Error::Protocol(err.to_string()), + E::HttpFormat(err) => Error::Protocol(err.to_string()), + other => Error::Other(other.to_string()), + } +} diff --git a/websock-tungstenite-mux/src/tls/cert.rs b/websock-tungstenite-mux/src/tls/cert.rs new file mode 100644 index 0000000..a1cbf0f --- /dev/null +++ b/websock-tungstenite-mux/src/tls/cert.rs @@ -0,0 +1,96 @@ +//! Certificate handling utilities. + +use rustls::client::danger::ServerCertVerifier; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; +use std::{fs, path::Path, sync::Arc}; +use websock_proto::{Error, Result}; + +/// Load native certificates from the host system. +pub fn get_native_certs() -> Result { + let mut root_store = rustls::RootCertStore::empty(); + + let cert_result = rustls_native_certs::load_native_certs(); + + for cert in cert_result.certs { + let _ = root_store.add(cert); + } + + Ok(root_store) +} + +/// Load a certificate chain from a file (.der or .pem). +pub fn load_certs(cert_path: &Path) -> Result>> { + let cert_bytes = fs::read(cert_path).map_err(|e| Error::Io(e.to_string()))?; + + if cert_path.extension().map_or(false, |x| x == "der") { + return Ok(vec![CertificateDer::from(cert_bytes)]); + } + + rustls_pemfile::certs(&mut &*cert_bytes) + .collect::, std::io::Error>>() + .map_err(|e| Error::Io(e.to_string())) +} + +/// Certificate verifier that unconditionally accepts certificates. +/// +/// # Warning +/// This is vulnerable to MITM attacks and must only be used for testing. +#[derive(Debug)] +pub struct SkipServerVerification(Arc); + +impl SkipServerVerification { + /// Create a verifier using the default ring provider. + pub fn new() -> Arc { + Self::with_provider(Arc::new(rustls::crypto::ring::default_provider())) + } + + /// Create a verifier with the provided crypto provider. + pub fn with_provider(provider: Arc) -> Arc { + Arc::new(Self(provider)) + } +} + +impl ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> std::result::Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} diff --git a/websock-tungstenite-mux/src/tls/key.rs b/websock-tungstenite-mux/src/tls/key.rs new file mode 100644 index 0000000..10c8987 --- /dev/null +++ b/websock-tungstenite-mux/src/tls/key.rs @@ -0,0 +1,22 @@ +//! Private key handling utilities. + +use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; +use std::{fs, path::Path}; +use websock_proto::{Error, Result}; + +/// Load a private key from a file. +pub fn load_key(key_path: &Path) -> Result> { + let key = fs::read(key_path).map_err(|e| Error::Io(e.to_string()))?; + + let key = if key_path.extension().map_or(false, |x| x == "der") { + // Treat raw DER as PKCS#8. + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key)) + } else { + // Decode PEM. + rustls_pemfile::private_key(&mut &*key) + .map_err(|e| Error::Tls(e.to_string()))? + .ok_or_else(|| Error::Io("no keys found".into()))? + }; + + Ok(key) +} diff --git a/websock-tungstenite-mux/src/tls/mod.rs b/websock-tungstenite-mux/src/tls/mod.rs new file mode 100644 index 0000000..6079528 --- /dev/null +++ b/websock-tungstenite-mux/src/tls/mod.rs @@ -0,0 +1,259 @@ +//! TLS configuration helpers for managing certificates and private keys. + +pub mod cert; +pub mod key; + +use websock_proto::{Error, Result}; + +use cert::SkipServerVerification; +use rustls::client::{ClientConfig, WebPkiServerVerifier}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; +use rustls::server::ServerConfig; +use std::path::Path; +use std::sync::Arc; + +/// Alias of [`rustls::server::ServerConfig`]. +pub type TlsServerConfig = rustls::server::ServerConfig; + +/// Alias of [`rustls::client::ClientConfig`]. +pub type TlsClientConfig = rustls::client::ClientConfig; + +/// Generate a self-signed certificate and private key (DER). +pub fn generate_self_signed_pair_der( + subject_alt_names: Vec, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert = rcgen::generate_simple_self_signed(subject_alt_names) + .map_err(|e| Error::Tls(e.to_string()))?; + + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der())); + let cert_chain = vec![CertificateDer::from(cert.cert)]; + Ok((cert_chain, key)) +} + +/// Generate a self-signed certificate and private key (PEM). +pub fn generate_self_signed_pair_pem( + subject_alt_names: Vec, +) -> Result<(Vec, String)> { + let cert = rcgen::generate_simple_self_signed(subject_alt_names) + .map_err(|e| Error::Tls(e.to_string()))?; + + let key = cert.signing_key.serialize_pem(); + let cert_chain = vec![cert.cert.pem()]; + Ok((cert_chain, key)) +} + +/// Load certificate chain and private key from files. +pub fn load_cert( + cert_path: &Path, + key_path: &Path, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert_chain = cert::load_certs(cert_path)?; + let key = key::load_key(key_path)?; + Ok((cert_chain, key)) +} + +/// Bundled TLS configuration for both client and server use. +#[derive(Debug, Clone)] +pub struct TlsConfig { + /// Client-side rustls configuration. + pub client_config: ClientConfig, + /// Server-side rustls configuration. + pub server_config: ServerConfig, +} + +impl TlsConfig { + /// Create a new TLS configuration with the specified certificate and private key. + pub fn with_cert(cert_path: &Path, key_path: &Path) -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = TlsServerConfigBuilder::new_with_cert(cert_path, key_path)? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with self-signed certificates (localhost). + pub fn with_self_signed_certs() -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with system certificates (server side is self-signed localhost). + pub fn new_native_config() -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with no certificate verification (testing only). + pub fn new_insecure_config() -> Result { + let client_config = TlsClientConfigBuilder::new_insecure()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } +} + +/// Server config builder (owned builder). +#[derive(Debug, Clone)] +pub struct TlsServerConfigBuilder { + inner: TlsServerConfig, +} + +impl TlsServerConfigBuilder { + /// Create an insecure server config with a self-signed certificate. + pub fn new_insecure(subject_alt_names: Vec) -> Result { + let (certs, key) = generate_self_signed_pair_der(subject_alt_names)?; + let inner = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::Tls(e.to_string()))?; + Ok(Self { inner }) + } + + /// Create a server config using certificate and private key files. + pub fn new_with_cert(cert_path: &Path, key_path: &Path) -> Result { + let (certs, key) = load_cert(cert_path, key_path)?; + let inner = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::Tls(e.to_string()))?; + Ok(Self { inner }) + } + + /// Create a server config using a self-signed certificate. + pub fn new_with_self_signed_certs(subject_alt_names: Vec) -> Result { + Self::new_insecure(subject_alt_names) + } + + /// Set ALPN protocol identifiers. + pub fn with_alpn_protocols(mut self, protocols: Vec>) -> Self { + self.inner.alpn_protocols = protocols; + self + } + + /// Finalize the builder and return the server config. + pub fn build(self) -> TlsServerConfig { + self.inner + } +} + +/// Client config builder (owned builder). +#[derive(Debug, Clone)] +pub struct TlsClientConfigBuilder { + inner: TlsClientConfig, +} + +impl TlsClientConfigBuilder { + /// Create an insecure client config that skips certificate verification. + pub fn new_insecure() -> Result { + let inner = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Create a client config that uses the system root store. + pub fn new_with_native_certs() -> Result { + let native_certs = cert::get_native_certs()?; + let inner = ClientConfig::builder() + .with_root_certificates(native_certs) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Create a client config with a custom WebPKI verifier. + pub fn new_with_webpki_verifier(verifier: Arc) -> Result { + let inner = ClientConfig::builder() + .with_webpki_verifier(verifier) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Set ALPN protocol identifiers. + pub fn with_alpn_protocols(mut self, protocols: Vec>) -> Self { + self.inner.alpn_protocols = protocols; + self + } + + /// Finalize the builder and return the client config. + pub fn build(self) -> TlsClientConfig { + self.inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_self_signed_pair_der() { + let (cert_chain, key) = generate_self_signed_pair_der(vec!["localhost".into()]).unwrap(); + let rustls_server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key); + + if let Err(e) = rustls_server_config { + panic!("Failed to create ServerConfig: {e}"); + } + } + + #[test] + fn test_generate_self_signed_pair_pem() { + let (cert_chain, key) = generate_self_signed_pair_pem(vec!["localhost".into()]).unwrap(); + + let cert_path = Path::new("cert.pem"); + let key_path = Path::new("key.pem"); + std::fs::write(cert_path, cert_chain.join("\n")).unwrap(); + std::fs::write(key_path, key).unwrap(); + + let (cert_chain, key) = load_cert(cert_path, key_path).unwrap(); + let rustls_server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key); + + if let Err(e) = rustls_server_config { + panic!("Failed to create ServerConfig: {e}"); + } + + std::fs::remove_file(cert_path).unwrap(); + std::fs::remove_file(key_path).unwrap(); + } +} diff --git a/websock-wasm-mux/Cargo.toml b/websock-wasm-mux/Cargo.toml new file mode 100644 index 0000000..e413eab --- /dev/null +++ b/websock-wasm-mux/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "websock-wasm-mux" +version.workspace = true +edition.workspace = true +authors.workspace = true + +[dependencies] diff --git a/websock-wasm-mux/src/lib.rs b/websock-wasm-mux/src/lib.rs new file mode 100644 index 0000000..d951046 --- /dev/null +++ b/websock-wasm-mux/src/lib.rs @@ -0,0 +1,3 @@ +pub fn todo() { + unimplemented!() +} From ccaf87e81c1ba9f5e98fe4ccefc64fa2e812ffea Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 14:11:24 +0900 Subject: [PATCH 02/16] Update Cargo.toml --- websock-mux-proto/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websock-mux-proto/Cargo.toml b/websock-mux-proto/Cargo.toml index 5f18011..cb83977 100644 --- a/websock-mux-proto/Cargo.toml +++ b/websock-mux-proto/Cargo.toml @@ -16,4 +16,4 @@ thiserror = "2" websock-proto = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio = { version = "1", features = ["rt", "sync", "macros"] } +tokio = { version = "1", features = ["io-util"] } From c37fe88d5b9a22654e7b72f48f3fe6561eb5161a Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 14:12:43 +0900 Subject: [PATCH 03/16] Format code with cargo fmt --- websock-mux-proto/src/stream.rs | 39 ++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/websock-mux-proto/src/stream.rs b/websock-mux-proto/src/stream.rs index bc6a236..74886c0 100644 --- a/websock-mux-proto/src/stream.rs +++ b/websock-mux-proto/src/stream.rs @@ -12,15 +12,17 @@ pub enum StreamDir { pub struct StreamId(pub u64); impl StreamId { - pub fn new(counter: u64, is_server: bool, dir: StreamDir) -> Result { + pub fn new( + counter: u64, + is_server: bool, + dir: StreamDir, + ) -> Result { let initiator = if is_server { 1 } else { 0 }; let dir_bit = match dir { StreamDir::Bi => 0, StreamDir::Uni => 1, }; - let value = counter - .checked_shl(2) - .ok_or(VarIntBoundsExceeded)? + let value = counter.checked_shl(2).ok_or(VarIntBoundsExceeded)? | ((dir_bit as u64) << 1) | (initiator as u64); VarInt::from_u64(value)?; @@ -42,12 +44,29 @@ impl StreamId { #[derive(Debug, Clone, PartialEq, Eq)] pub enum Frame { - OpenUni { id: StreamId }, - OpenBi { id: StreamId }, - Stream { id: StreamId, data: Bytes, fin: bool }, - ResetStream { id: StreamId, code: u64 }, - StopSending { id: StreamId, code: u64 }, - ConnectionClose { code: u64, reason: String }, + OpenUni { + id: StreamId, + }, + OpenBi { + id: StreamId, + }, + Stream { + id: StreamId, + data: Bytes, + fin: bool, + }, + ResetStream { + id: StreamId, + code: u64, + }, + StopSending { + id: StreamId, + code: u64, + }, + ConnectionClose { + code: u64, + reason: String, + }, } impl Frame { From 31756cf7adb1bf10e4647ca9f6f519cbb53ca992 Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 14:13:08 +0900 Subject: [PATCH 04/16] Add tests --- websock-mux-proto/tests/frame.rs | 107 ++++++++++++++++++++++++++++++ websock-mux-proto/tests/varint.rs | 55 +++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 websock-mux-proto/tests/frame.rs create mode 100644 websock-mux-proto/tests/varint.rs diff --git a/websock-mux-proto/tests/frame.rs b/websock-mux-proto/tests/frame.rs new file mode 100644 index 0000000..3b34056 --- /dev/null +++ b/websock-mux-proto/tests/frame.rs @@ -0,0 +1,107 @@ +use bytes::{Bytes, BytesMut}; +use std::io::Cursor; + +use websock_mux_proto::{Frame, FrameDecodeError, StreamDir, StreamId, VarInt}; + +fn roundtrip(frame: &Frame) { + let encoded = frame.encode().freeze(); + let mut cur = Cursor::new(encoded); + let decoded = Frame::decode(&mut cur).unwrap(); + assert_eq!(&decoded, frame); +} + +#[test] +fn frame_roundtrip_all_variants() { + let bi_client = StreamId::new(0, false, StreamDir::Bi).unwrap(); + let uni_client = StreamId::new(1, false, StreamDir::Uni).unwrap(); + let bi_server = StreamId::new(2, true, StreamDir::Bi).unwrap(); + + let frames = vec![ + Frame::OpenUni { id: uni_client }, + Frame::OpenBi { id: bi_client }, + Frame::Stream { + id: bi_client, + data: Bytes::from_static(b"hello"), + fin: false, + }, + Frame::Stream { + id: bi_server, + data: Bytes::from(vec![0u8; 1024]), + fin: true, + }, + Frame::ResetStream { + id: bi_client, + code: 0, + }, + Frame::ResetStream { + id: bi_client, + code: 0xdead_beef, + }, + Frame::StopSending { + id: bi_server, + code: 42, + }, + Frame::ConnectionClose { + code: 0, + reason: "bye".to_string(), + }, + Frame::ConnectionClose { + code: 100, + reason: "test".to_string(), + }, + ]; + + for f in frames { + roundtrip(&f); + } +} + +#[test] +fn frame_decode_unknown_tag() { + let mut buf = BytesMut::new(); + VarInt::from_u32(99).encode(&mut buf); // unknown tag + let mut cur = Cursor::new(buf.freeze()); + + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::UnknownTag(99) => {} + other => panic!("unexpected error: {other:?}"), + } +} + +#[test] +fn frame_decode_unexpected_end_stream_payload() { + // Build a Stream frame but truncate the payload. + let id = StreamId::new(0, false, StreamDir::Bi).unwrap(); + + let mut buf = BytesMut::new(); + VarInt::from_u32(2).encode(&mut buf); // Stream + VarInt::from_u64(id.0).unwrap().encode(&mut buf); + VarInt::from_u32(0).encode(&mut buf); // fin = false + VarInt::from_u32(5).encode(&mut buf); // len=5 + buf.extend_from_slice(b"he"); // only 2 bytes, should be 5 + + let mut cur = Cursor::new(buf.freeze()); + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::UnexpectedEnd => {} + other => panic!("unexpected error: {other:?}"), + } +} + +#[test] +fn frame_decode_invalid_utf8_reason() { + // ConnectionClose: tag=5, code=0, len=2, bytes=[0xFF,0xFF] + let mut buf = BytesMut::new(); + VarInt::from_u32(5).encode(&mut buf); + VarInt::from_u32(0).encode(&mut buf); + VarInt::from_u32(2).encode(&mut buf); + buf.extend_from_slice(&[0xFF, 0xFF]); + + let mut cur = Cursor::new(buf.freeze()); + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::InvalidUtf8 => {} + other => panic!("unexpected error: {other:?}"), + } +} diff --git a/websock-mux-proto/tests/varint.rs b/websock-mux-proto/tests/varint.rs new file mode 100644 index 0000000..4b68074 --- /dev/null +++ b/websock-mux-proto/tests/varint.rs @@ -0,0 +1,55 @@ +use bytes::BytesMut; +use std::io::Cursor; + +use websock_mux_proto::{VarInt, VarIntBoundsExceeded}; + +#[test] +fn varint_roundtrip_boundaries() { + // (value, expected_encoded_size) + let cases: &[(u64, usize)] = &[ + (0, 1), + (1, 1), + (63, 1), + (64, 2), + (16383, 2), + (16384, 4), + ((1u64 << 30) - 1, 4), + (1u64 << 30, 8), + ((1u64 << 62) - 1, 8), + ]; + + for &(value, expected_size) in cases { + let v = VarInt::from_u64(value).unwrap(); + assert_eq!(v.size(), expected_size, "size mismatch for {}", value); + + let mut buf = BytesMut::new(); + v.encode(&mut buf); + + // encoded length should match + assert_eq!( + buf.len(), + expected_size, + "encoded length mismatch for {}", + value + ); + + let mut cur = Cursor::new(buf.freeze()); + let decoded = VarInt::decode(&mut cur).unwrap(); + assert_eq!( + decoded.into_inner(), + value, + "roundtrip mismatch for {}", + value + ); + + // should consume all + assert_eq!(cur.position() as usize, expected_size); + } +} + +#[test] +fn varint_rejects_too_large() { + let too_large = 1u64 << 62; + let err = VarInt::from_u64(too_large).unwrap_err(); + assert_eq!(err, VarIntBoundsExceeded); +} From e690786ce885e668ef752f15b02d25f151d6c197 Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 14:15:25 +0900 Subject: [PATCH 05/16] Format code with cargo fmt --- websock-mux/examples/echo-client-mux.rs | 4 ++-- websock-mux/examples/echo-server-mux.rs | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs index 4f57cba..0ffaff7 100644 --- a/websock-mux/examples/echo-client-mux.rs +++ b/websock-mux/examples/echo-client-mux.rs @@ -3,14 +3,14 @@ //! This demonstrates opening a bidirectional stream and echoing bytes. use clap::Parser; -use rustls::{client::ClientConfig, RootCertStore}; +use rustls::{RootCertStore, client::ClientConfig}; use std::{path, sync::Arc}; use tracing::Level; use tracing_subscriber::FmtSubscriber; use url::Url; use websock_mux::{ - tls::{self, TlsClientConfigBuilder}, Client, ClientBuilder, + tls::{self, TlsClientConfigBuilder}, }; const DEFAULT_ECHO_URL: &str = "ws://127.0.0.1:9001"; diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs index 456323e..606f970 100644 --- a/websock-mux/examples/echo-server-mux.rs +++ b/websock-mux/examples/echo-server-mux.rs @@ -3,7 +3,7 @@ //! This server accepts mux sessions and echoes bytes over each bidirectional stream. use clap::Parser; -use std::{io, fs, path}; +use std::{fs, io, path}; use tracing::Level; use tracing_subscriber::FmtSubscriber; use websock_mux::{Server, ServerBuilder}; @@ -50,9 +50,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!("TLS enabled (wss://)"); } (None, None) => { - tracing::warn!( - "TLS disabled (ws://). Provide --tls-cert/--tls-key to enable wss://" - ); + tracing::warn!("TLS disabled (ws://). Provide --tls-cert/--tls-key to enable wss://"); } _ => anyhow::bail!("both --tls-cert and --tls-key must be provided together, or neither"), } @@ -92,7 +90,10 @@ async fn main() -> anyhow::Result<()> { fn load_pem_cert_and_key( cert_path: &path::Path, key_path: &path::Path, -) -> anyhow::Result<(Vec>, rustls::pki_types::PrivateKeyDer<'static>)> { +) -> anyhow::Result<( + Vec>, + rustls::pki_types::PrivateKeyDer<'static>, +)> { let chain_file = fs::File::open(cert_path)?; let mut chain_reader = io::BufReader::new(chain_file); let chain: Vec> = From d82b371d6b0aa5320718a3c8de7bad3172bc5168 Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 14:19:18 +0900 Subject: [PATCH 06/16] Add limits --- websock-tungstenite-mux/src/builder.rs | 36 +++++- websock-tungstenite-mux/src/client.rs | 18 ++- websock-tungstenite-mux/src/lib.rs | 5 +- websock-tungstenite-mux/src/server.rs | 75 ++++++------ websock-tungstenite-mux/src/session.rs | 155 ++++++++++++++++++++----- 5 files changed, 209 insertions(+), 80 deletions(-) diff --git a/websock-tungstenite-mux/src/builder.rs b/websock-tungstenite-mux/src/builder.rs index 87eeaa6..9ceef51 100644 --- a/websock-tungstenite-mux/src/builder.rs +++ b/websock-tungstenite-mux/src/builder.rs @@ -1,8 +1,9 @@ -use websock_proto::Result; -use std::sync::Arc; use rustls::{ClientConfig, ServerConfig}; +use std::sync::Arc; +use websock_proto::Result; -use crate::{bind, Client, Server}; +use crate::session::Limits; +use crate::{Client, Server, bind}; use websock_mux_proto::SUBPROTOCOL; /// Builder for a mux WebSocket client. @@ -10,6 +11,7 @@ use websock_mux_proto::SUBPROTOCOL; pub struct ClientBuilder { pub(crate) opts: websock_proto::ConnectOptions, pub(crate) tls: Option>, + pub(crate) limits: Limits, } impl Default for ClientBuilder { @@ -22,7 +24,11 @@ impl ClientBuilder { pub fn new() -> Self { let mut opts = websock_proto::ConnectOptions::default(); opts.protocols.insert(0, SUBPROTOCOL.to_string()); - Self { opts, tls: None } + Self { + opts, + tls: None, + limits: Limits::default(), + } } pub fn with_options(mut self, opts: websock_proto::ConnectOptions) -> Self { @@ -39,6 +45,12 @@ impl ClientBuilder { self } + /// Configure session limits + pub fn with_limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { self.opts.headers.push((name.into(), value.into())); self @@ -53,6 +65,7 @@ impl ClientBuilder { Client { opts: self.opts, tls: self.tls, + limits: self.limits, } } } @@ -62,13 +75,18 @@ impl ClientBuilder { pub struct ServerBuilder { pub(crate) opts: websock_proto::ServerOptions, pub(crate) tls: Option, + pub(crate) limits: Limits, } impl ServerBuilder { pub fn new() -> Self { let mut opts = websock_proto::ServerOptions::default(); opts.protocols.push(SUBPROTOCOL.to_string()); - Self { opts, tls: None } + Self { + opts, + tls: None, + limits: Limits::default(), + } } pub fn with_options(mut self, opts: websock_proto::ServerOptions) -> Self { @@ -85,6 +103,12 @@ impl ServerBuilder { self } + /// Configure session limits + pub fn with_limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { self.opts.headers.push((name.into(), value.into())); self @@ -99,6 +123,6 @@ impl ServerBuilder { where A: tokio::net::ToSocketAddrs, { - bind(addr, self.opts, self.tls).await + bind(addr, self.opts, self.tls, self.limits).await } } diff --git a/websock-tungstenite-mux/src/client.rs b/websock-tungstenite-mux/src/client.rs index 62c8e60..cb46ec0 100644 --- a/websock-tungstenite-mux/src/client.rs +++ b/websock-tungstenite-mux/src/client.rs @@ -9,8 +9,9 @@ use rustls::ClientConfig; use tokio_tungstenite::Connector; use websock_proto::{Error, Result}; -use crate::session::map_tungstenite_err; use crate::Session; +use crate::session::Limits; +use crate::session::map_tungstenite_err; use websock_mux_proto::SUBPROTOCOL; fn negotiated_protocol(resp: &http::Response>>) -> Option<&str> { @@ -37,6 +38,7 @@ fn validate_client_protocols(opts: &websock_proto::ConnectOptions) -> Result<()> pub struct Client { pub(crate) opts: websock_proto::ConnectOptions, pub(crate) tls: Option>, + pub(crate) limits: Limits, } impl Client { @@ -86,14 +88,10 @@ impl Client { } let connector = tls.map(Connector::Rustls); - let (stream, response) = tokio_tungstenite::connect_async_tls_with_config( - request, - None, - false, - connector, - ) - .await - .map_err(map_tungstenite_err)?; + let (stream, response) = + tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector) + .await + .map_err(map_tungstenite_err)?; let proto = negotiated_protocol(&response) .ok_or_else(|| Error::Protocol("missing SEC_WEBSOCKET_PROTOCOL in response".into()))?; @@ -101,6 +99,6 @@ impl Client { return Err(Error::Protocol(format!("subprotocol mismatch: {proto}"))); } - Session::new(stream, false) + Session::new(stream, false, self.limits.clone()) } } diff --git a/websock-tungstenite-mux/src/lib.rs b/websock-tungstenite-mux/src/lib.rs index 70048df..3b2bc0e 100644 --- a/websock-tungstenite-mux/src/lib.rs +++ b/websock-tungstenite-mux/src/lib.rs @@ -11,7 +11,8 @@ pub mod tls; pub use builder::{ClientBuilder, ServerBuilder}; pub use client::Client; -pub use server::{bind, Server}; +pub use server::{Server, bind}; +pub use session::Limits; pub use session::{RecvStream, SendStream, Session}; pub use tls::{ TlsClientConfig, TlsClientConfigBuilder, TlsConfig, TlsServerConfig, TlsServerConfigBuilder, @@ -23,6 +24,7 @@ mod tests { use tokio::sync::mpsc; use websock_mux_proto::{Frame, StreamDir, StreamId}; + use crate::session::Limits; use crate::session::SessionInner; #[test] @@ -45,6 +47,7 @@ mod tests { let (accept_bi_tx, _accept_bi_rx) = mpsc::channel(4); let inner = std::sync::Arc::new(SessionInner::new( false, + Limits::default(), outbound_tx, accept_uni_tx, accept_bi_tx, diff --git a/websock-tungstenite-mux/src/server.rs b/websock-tungstenite-mux/src/server.rs index 526073d..744fc38 100644 --- a/websock-tungstenite-mux/src/server.rs +++ b/websock-tungstenite-mux/src/server.rs @@ -1,17 +1,18 @@ use std::sync::Arc; -use tokio_tungstenite::tungstenite; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio_rustls::TlsAcceptor; use tokio_tungstenite::accept_hdr_async; +use tokio_tungstenite::tungstenite; use tokio_tungstenite::tungstenite::handshake::server; use tungstenite::http; use tungstenite::http::header::{HeaderName, HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use websock_proto::{Error, Result}; -use crate::session::map_tungstenite_err; use crate::Session; +use crate::session::Limits; +use crate::session::map_tungstenite_err; use websock_mux_proto::SUBPROTOCOL; /// Marker trait for IO types usable by the server. @@ -64,6 +65,7 @@ pub async fn bind( addr: A, opts: websock_proto::ServerOptions, tls: Option, + limits: Limits, ) -> Result where A: ToSocketAddrs, @@ -75,7 +77,9 @@ where validate_protocols(&opts)?; if !opts.protocols.iter().any(|p| p == SUBPROTOCOL) { - return Err(Error::Protocol("SUBPROTOCOL missing in ServerOptions".into())); + return Err(Error::Protocol( + "SUBPROTOCOL missing in ServerOptions".into(), + )); } let acceptor = tls.map(|cfg| TlsAcceptor::from(Arc::new(cfg))); @@ -85,6 +89,7 @@ where opts, headers, acceptor, + limits, }) } @@ -93,6 +98,7 @@ pub struct Server { opts: websock_proto::ServerOptions, headers: Vec<(HeaderName, HeaderValue)>, acceptor: Option, + limits: Limits, } impl Server { @@ -116,39 +122,42 @@ impl Server { let headers = self.headers.clone(); let allowed = self.opts.protocols.clone(); - let ws = accept_hdr_async(stream, move |req: &server::Request, mut resp: server::Response| { - // Additional headers from configuration - for (name, value) in &headers { - resp.headers_mut().append(name, value.clone()); - } - - // websock-mux is required protocol - let Some(protocol) = select_protocol(req, &allowed) else { - return Err(http::Response::builder() - .status(http::StatusCode::BAD_REQUEST) - .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) - .unwrap()); - }; - - // Confirm that the required protocol is present - if !protocol.eq_ignore_ascii_case(SUBPROTOCOL) { - return Err(http::Response::builder() - .status(http::StatusCode::BAD_REQUEST) - .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) - .unwrap()); - } - - resp.headers_mut().insert( - http::header::SEC_WEBSOCKET_PROTOCOL, - http::HeaderValue::from_str(&protocol).expect("validated"), - ); - - Ok(resp) - }) + let ws = accept_hdr_async( + stream, + move |req: &server::Request, mut resp: server::Response| { + // Additional headers from configuration + for (name, value) in &headers { + resp.headers_mut().append(name, value.clone()); + } + + // websock-mux is required protocol + let Some(protocol) = select_protocol(req, &allowed) else { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + }; + + // Confirm that the required protocol is present + if !protocol.eq_ignore_ascii_case(SUBPROTOCOL) { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + } + + resp.headers_mut().insert( + http::header::SEC_WEBSOCKET_PROTOCOL, + http::HeaderValue::from_str(&protocol).expect("validated"), + ); + + Ok(resp) + }, + ) .await .map_err(map_tungstenite_err)?; - Session::new(ws, true) + Session::new(ws, true, self.limits.clone()) } pub fn local_addr(&self) -> Result { diff --git a/websock-tungstenite-mux/src/session.rs b/websock-tungstenite-mux/src/session.rs index 1fe9520..f4b5bdf 100644 --- a/websock-tungstenite-mux/src/session.rs +++ b/websock-tungstenite-mux/src/session.rs @@ -1,18 +1,52 @@ use std::collections::HashMap; use std::sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, Arc, + atomic::{AtomicBool, AtomicU64, Ordering}, }; use bytes::{BufMut, Bytes}; use futures_util::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{Mutex, mpsc}; use tokio_tungstenite::tungstenite; use websock_proto::{Error, Result}; use websock_mux_proto::stream::{Frame, StreamDir, StreamId}; +/// Session limits to prevent unbounded buffering / DoS. +#[derive(Debug, Clone)] +pub struct Limits { + /// Maximum size of a single WebSocket binary message accepted by the inbound task. + pub max_ws_message_size: usize, + /// Maximum `Stream` frame payload size. + pub max_stream_data_per_frame: usize, + /// Maximum number of concurrently open receive streams. + pub max_open_streams: usize, + /// Per-stream receive event queue length. + pub recv_event_queue_len: usize, + /// Session outbound queue length. + pub outbound_queue_len: usize, + /// Queue length for accepting inbound uni streams. + pub accept_uni_queue_len: usize, + /// Queue length for accepting inbound bi streams. + pub accept_bi_queue_len: usize, +} + +impl Default for Limits { + fn default() -> Self { + Self { + // Safe defaults for a WebSocket fallback transport. + max_ws_message_size: 1 * 1024 * 1024, // 1 MiB + max_stream_data_per_frame: 256 * 1024, // 256 KiB + max_open_streams: 1024, + recv_event_queue_len: 128, + outbound_queue_len: 256, + accept_uni_queue_len: 128, + accept_bi_queue_len: 128, + } + } +} + #[derive(Clone)] pub struct Session { inner: Arc, @@ -24,16 +58,18 @@ impl Session { pub(crate) fn new( stream: tokio_tungstenite::WebSocketStream, is_server: bool, + limits: Limits, ) -> Result where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (outbound_tx, outbound_rx) = mpsc::channel(128); - let (accept_uni_tx, accept_uni_rx) = mpsc::channel(128); - let (accept_bi_tx, accept_bi_rx) = mpsc::channel(128); + let (outbound_tx, outbound_rx) = mpsc::channel(limits.outbound_queue_len); + let (accept_uni_tx, accept_uni_rx) = mpsc::channel(limits.accept_uni_queue_len); + let (accept_bi_tx, accept_bi_rx) = mpsc::channel(limits.accept_bi_queue_len); let inner = Arc::new(SessionInner::new( is_server, + limits, outbound_tx, accept_uni_tx, accept_bi_tx, @@ -139,8 +175,10 @@ impl SendStream { impl Drop for SendStream { fn drop(&mut self) { if !self.finished.load(Ordering::SeqCst) { - self.session - .try_send_frame(Frame::ResetStream { id: self.id, code: 0 }); + self.session.try_send_frame(Frame::ResetStream { + id: self.id, + code: 0, + }); } } } @@ -254,14 +292,17 @@ impl RecvStream { impl Drop for RecvStream { fn drop(&mut self) { if !self.finished { - self.session - .try_send_frame(Frame::StopSending { id: self.id, code: 0 }); + self.session.try_send_frame(Frame::StopSending { + id: self.id, + code: 0, + }); } } } pub(crate) struct SessionInner { is_server: bool, + limits: Limits, outbound_tx: mpsc::Sender, accept_uni_tx: Mutex>>, accept_bi_tx: Mutex>>, @@ -274,12 +315,14 @@ pub(crate) struct SessionInner { impl SessionInner { pub(crate) fn new( is_server: bool, + limits: Limits, outbound_tx: mpsc::Sender, accept_uni_tx: mpsc::Sender, accept_bi_tx: mpsc::Sender<(SendStream, RecvStream)>, ) -> Self { Self { is_server, + limits, outbound_tx, accept_uni_tx: Mutex::new(Some(accept_uni_tx)), accept_bi_tx: Mutex::new(Some(accept_bi_tx)), @@ -309,6 +352,10 @@ impl SessionInner { match msg { tungstenite::Message::Binary(data) => { + if data.len() > inbound.limits.max_ws_message_size { + let _ = inbound.protocol_error(2, "ws message too large").await; + break; + } let mut cursor = &data[..]; let frame = match Frame::decode(&mut cursor) { Ok(f) => f, @@ -319,9 +366,7 @@ impl SessionInner { } } tungstenite::Message::Ping(p) => { - let _ = inbound - .outbound_tx - .try_send(tungstenite::Message::Pong(p)); + let _ = inbound.outbound_tx.try_send(tungstenite::Message::Pong(p)); } tungstenite::Message::Close(_) => break, _ => {} @@ -351,13 +396,21 @@ impl SessionInner { .await; } if id.initiator_is_server() != self.peer_is_server() { - return self - .protocol_error(1, "OpenUni with wrong initiator") - .await; + return self.protocol_error(1, "OpenUni with wrong initiator").await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } } - + { let streams = self.streams.lock().await; + if streams.len() >= self.limits.max_open_streams { + return self.protocol_error(3, "too many open streams").await; + } if streams.contains_key(&id) { return self.protocol_error(1, "duplicate stream open").await; } @@ -367,25 +420,42 @@ impl SessionInner { let tx = { self.accept_uni_tx.lock().await.clone() }; if let Some(tx) = tx { - tx.send(recv).await.map_err(|_| Error::Closed)?; + match tx.try_send(recv) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + // Application is not accepting inbound streams fast enough. + // Reset this stream and keep the connection alive. + let mut streams = self.streams.lock().await; + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => return Err(Error::Closed), + } } else { return Err(Error::Closed); } } Frame::OpenBi { id } => { if id.dir() != StreamDir::Bi { - return self - .protocol_error(1, "OpenBi with non-bi StreamId") - .await; + return self.protocol_error(1, "OpenBi with non-bi StreamId").await; } if id.initiator_is_server() != self.peer_is_server() { - return self - .protocol_error(1, "OpenBi with wrong initiator") - .await; + return self.protocol_error(1, "OpenBi with wrong initiator").await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } } { let streams = self.streams.lock().await; + if streams.len() >= self.limits.max_open_streams { + return self.protocol_error(3, "too many open streams").await; + } if streams.contains_key(&id) { return self.protocol_error(1, "duplicate stream open").await; } @@ -396,12 +466,24 @@ impl SessionInner { let tx = { self.accept_bi_tx.lock().await.clone() }; if let Some(tx) = tx { - tx.send((send, recv)).await.map_err(|_| Error::Closed)?; + match tx.try_send((send, recv)) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + let mut streams = self.streams.lock().await; + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => return Err(Error::Closed), + } } else { return Err(Error::Closed); } } Frame::Stream { id, data, fin } => { + if data.len() > self.limits.max_stream_data_per_frame { + return self.protocol_error(2, "stream data too large").await; + } let mut streams = self.streams.lock().await; let Some(tx) = streams.get(&id) else { @@ -410,7 +492,20 @@ impl SessionInner { .await; }; - let _ = tx.send(RecvEvent { data, fin }).await; + match tx.try_send(RecvEvent { data, fin }) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + // Receiver is too slow. Reset just this stream to prevent + // unbounded buffering or blocking the inbound loop. + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + streams.remove(&id); + return Ok(()); + } + } if fin { streams.remove(&id); @@ -419,9 +514,7 @@ impl SessionInner { Frame::ResetStream { id, .. } | Frame::StopSending { id, .. } => { let mut streams = self.streams.lock().await; if streams.remove(&id).is_none() { - return self - .protocol_error(1, "reset/stop on unknown stream") - .await; + return self.protocol_error(1, "reset/stop on unknown stream").await; } } Frame::ConnectionClose { .. } => { @@ -433,7 +526,7 @@ impl SessionInner { } pub(crate) async fn register_recv_stream(self: &Arc, id: StreamId) -> RecvStream { - let (tx, rx) = mpsc::channel(128); + let (tx, rx) = mpsc::channel(self.limits.recv_event_queue_len); let mut streams = self.streams.lock().await; streams.insert(id, tx); RecvStream::new(id, self.clone(), rx) @@ -463,7 +556,9 @@ impl SessionInner { return; } let data = frame.encode().freeze(); - let _ = self.outbound_tx.try_send(tungstenite::Message::Binary(data)); + let _ = self + .outbound_tx + .try_send(tungstenite::Message::Binary(data)); } pub(crate) async fn close_all(&self) { From 1824f92ee69737deb87bf6c1a3a61d3d293c1483 Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 16:26:32 +0900 Subject: [PATCH 07/16] Update docs --- websock-tungstenite-mux/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/websock-tungstenite-mux/src/lib.rs b/websock-tungstenite-mux/src/lib.rs index 3b2bc0e..6cc0aca 100644 --- a/websock-tungstenite-mux/src/lib.rs +++ b/websock-tungstenite-mux/src/lib.rs @@ -1,7 +1,6 @@ //! Tokio + tokio-tungstenite based WebSocket multiplexing transport. //! //! This crate provides a QUIC/WebTransport-like logical stream interface over a single WebSocket. -//! The wire format (frames, StreamId, VarInt, etc.) lives in `websock-mux-proto`. mod builder; mod client; From ccc86bc59783207146b83d4cf8ad5022722c01bf Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 21:56:07 +0900 Subject: [PATCH 08/16] Add dev-cert for testing --- devcert/.gitignore | 4 ++++ devcert/generate.sh | 50 ++++++++++++++++++++++++++++++++++++++++++++ devcert/openssl.conf | 22 +++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 devcert/.gitignore create mode 100755 devcert/generate.sh create mode 100644 devcert/openssl.conf diff --git a/devcert/.gitignore b/devcert/.gitignore new file mode 100644 index 0000000..89a9fbc --- /dev/null +++ b/devcert/.gitignore @@ -0,0 +1,4 @@ +*.crt +*.hex +*.key +*.fingerprint diff --git a/devcert/generate.sh b/devcert/generate.sh new file mode 100755 index 0000000..bb60067 --- /dev/null +++ b/devcert/generate.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "${BASH_SOURCE[0]}")" + +CERT_NAME="localhost" +DAYS=10 +CONF="openssl.conf" + +echo "==> Generating self-signed certificate (${CERT_NAME})" +echo " validity: ${DAYS} days" +echo " config: ${CONF}" + +# Generate ECDSA P-256 private key +openssl ecparam \ + -genkey \ + -name prime256v1 \ + -out "${CERT_NAME}.key" + +# Generate self-signed certificate +openssl req \ + -x509 \ + -sha256 \ + -nodes \ + -days "${DAYS}" \ + -key "${CERT_NAME}.key" \ + -out "${CERT_NAME}.crt" \ + -config "${CONF}" \ + -extensions v3_req + +# Generate raw SHA-256 hash (DER -> hex, no colons) +openssl x509 \ + -in "${CERT_NAME}.crt" \ + -outform der \ +| openssl dgst -sha256 -binary \ +| xxd -p -c 256 \ +> "${CERT_NAME}.hex" + +# Also print human-readable fingerprint +openssl x509 \ + -in "${CERT_NAME}.crt" \ + -noout \ + -fingerprint -sha256 \ +> "${CERT_NAME}.fingerprint" + +echo "==> Done" +echo " - ${CERT_NAME}.crt" +echo " - ${CERT_NAME}.key" +echo " - ${CERT_NAME}.hex" +echo " - ${CERT_NAME}.fingerprint" diff --git a/devcert/openssl.conf b/devcert/openssl.conf new file mode 100644 index 0000000..fa64821 --- /dev/null +++ b/devcert/openssl.conf @@ -0,0 +1,22 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext +x509_extensions = v3_req +prompt = no + +[req_distinguished_name] +CN = localhost + +[req_ext] +subjectAltName = @alt_names + +[v3_req] +basicConstraints = CA:FALSE +keyUsage = digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +IP.2 = ::1 From 7d0f5f32d2114e3b1ee7f83c0966b4081f29ec9f Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 21:56:43 +0900 Subject: [PATCH 09/16] Update subprotocol name for mux --- websock-mux-proto/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websock-mux-proto/src/lib.rs b/websock-mux-proto/src/lib.rs index 94dbd2e..c4dc2d7 100644 --- a/websock-mux-proto/src/lib.rs +++ b/websock-mux-proto/src/lib.rs @@ -4,4 +4,4 @@ pub mod varint; pub use stream::{Frame, FrameDecodeError, StreamDir, StreamId}; pub use varint::{VarInt, VarIntBoundsExceeded, VarIntUnexpectedEnd}; -pub const SUBPROTOCOL: &str = "websock-mux/1"; +pub const SUBPROTOCOL: &str = "websock-mux-1"; From af3e57a7423de6cbeaf3b7149460bf1eb7309e4f Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 21:59:15 +0900 Subject: [PATCH 10/16] Improve ALPN handling --- websock-mux/examples/echo-client-mux.rs | 14 ++--- websock-mux/examples/echo-server-mux.rs | 6 +- websock-proto/src/lib.rs | 2 +- websock-proto/src/options.rs | 7 +++ websock-tungstenite-mux/src/builder.rs | 80 +++++++++++++++++++++---- websock-tungstenite/src/builder.rs | 72 ++++++++++++++++++---- 6 files changed, 147 insertions(+), 34 deletions(-) diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs index 0ffaff7..b8a38cd 100644 --- a/websock-mux/examples/echo-client-mux.rs +++ b/websock-mux/examples/echo-client-mux.rs @@ -4,7 +4,7 @@ use clap::Parser; use rustls::{RootCertStore, client::ClientConfig}; -use std::{path, sync::Arc}; +use std::path; use tracing::Level; use tracing_subscriber::FmtSubscriber; use url::Url; @@ -80,7 +80,7 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result { let tls_cfg: ClientConfig = if args.tls_disable_verify { tracing::warn!("disabling TLS certificate verification"); TlsClientConfigBuilder::new_insecure()? - .with_alpn_protocols(vec![b"h3".to_vec()]) + .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) .build() } else if let Some(path) = &args.tls_cert { let certs = tls::cert::load_certs(path)?; @@ -94,7 +94,7 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result { let mut cfg = ClientConfig::builder() .with_root_certificates(roots) .with_no_client_auth(); - cfg.alpn_protocols = vec![b"h3".to_vec()]; + cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()]; cfg } else if is_local { tracing::warn!( @@ -103,17 +103,15 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result { url ); TlsClientConfigBuilder::new_insecure()? - .with_alpn_protocols(vec![b"h3".to_vec()]) + .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) .build() } else { TlsClientConfigBuilder::new_with_native_certs()? - .with_alpn_protocols(vec![b"h3".to_vec()]) + .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) .build() }; - Ok(ClientBuilder::new() - .with_tls_config(Arc::new(tls_cfg)) - .build()) + Ok(ClientBuilder::new().with_tls_config(tls_cfg).build()) } /// Determine whether a URL points to a loopback host. diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs index 606f970..d3b0c3c 100644 --- a/websock-mux/examples/echo-server-mux.rs +++ b/websock-mux/examples/echo-server-mux.rs @@ -36,7 +36,7 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); - let mut builder = ServerBuilder::new(); + let mut builder = ServerBuilder::new().with_addr(args.addr); match (&args.tls_cert, &args.tls_key) { (Some(cert_path), Some(key_path)) => { @@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> { .with_no_client_auth() .with_single_cert(chain, key) .map_err(|e| anyhow::anyhow!(e.to_string()))?; - cfg.alpn_protocols = vec![b"h3".to_vec()]; + cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()]; builder = builder.with_tls_config(cfg); tracing::info!("TLS enabled (wss://)"); } @@ -55,7 +55,7 @@ async fn main() -> anyhow::Result<()> { _ => anyhow::bail!("both --tls-cert and --tls-key must be provided together, or neither"), } - let server: Server = builder.build(args.addr).await?; + let server: Server = builder.build().await?; let scheme = if args.tls_cert.is_some() { "wss" } else { "ws" }; tracing::info!("listening on {} ({}://{})", args.addr, scheme, args.addr); diff --git a/websock-proto/src/lib.rs b/websock-proto/src/lib.rs index 736fc21..f44f5da 100644 --- a/websock-proto/src/lib.rs +++ b/websock-proto/src/lib.rs @@ -10,4 +10,4 @@ mod options; pub use bytes::Bytes; pub use error::{Error, Result}; pub use message::{CloseFrame, Message}; -pub use options::{ConnectOptions, ServerOptions}; +pub use options::{ConnectOptions, ServerOptions, default_ws_alpn}; diff --git a/websock-proto/src/options.rs b/websock-proto/src/options.rs index 31459b0..187e372 100644 --- a/websock-proto/src/options.rs +++ b/websock-proto/src/options.rs @@ -1,3 +1,6 @@ +pub const ALPN_HTTP_1_1: &[u8] = b"http/1.1"; +pub const ALPN_H2: &[u8] = b"h2"; + /// Connection configuration shared by native and WebAssembly transports. #[derive(Debug, Clone, Default)] pub struct ConnectOptions { @@ -17,3 +20,7 @@ pub struct ServerOptions { /// Additional response headers (native only). pub headers: Vec<(String, String)>, } + +pub fn default_ws_alpn() -> Vec> { + vec![ALPN_HTTP_1_1.to_vec(), ALPN_H2.to_vec()] +} diff --git a/websock-tungstenite-mux/src/builder.rs b/websock-tungstenite-mux/src/builder.rs index 9ceef51..c451308 100644 --- a/websock-tungstenite-mux/src/builder.rs +++ b/websock-tungstenite-mux/src/builder.rs @@ -1,6 +1,7 @@ use rustls::{ClientConfig, ServerConfig}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::Arc; -use websock_proto::Result; +use websock_proto::{Result, default_ws_alpn}; use crate::session::Limits; use crate::{Client, Server, bind}; @@ -10,7 +11,8 @@ use websock_mux_proto::SUBPROTOCOL; #[derive(Debug, Clone)] pub struct ClientBuilder { pub(crate) opts: websock_proto::ConnectOptions, - pub(crate) tls: Option>, + pub(crate) tls: Option, + pub(crate) alpn: Option>>, pub(crate) limits: Limits, } @@ -27,6 +29,7 @@ impl ClientBuilder { Self { opts, tls: None, + alpn: None, limits: Limits::default(), } } @@ -40,11 +43,21 @@ impl ClientBuilder { } /// Configure a custom rustls client config (for `wss://`). - pub fn with_tls_config(mut self, tls: Arc) -> Self { + pub fn with_tls_config(mut self, tls: ClientConfig) -> Self { self.tls = Some(tls); self } + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + /// Configure session limits pub fn with_limits(mut self, limits: Limits) -> Self { self.limits = limits; @@ -61,11 +74,21 @@ impl ClientBuilder { self } - pub fn build(self) -> Client { + fn build_tls_config(&self) -> Option> { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(Arc::new(cfg)) + } + + pub fn build(&self) -> Client { Client { - opts: self.opts, - tls: self.tls, - limits: self.limits, + opts: self.opts.clone(), + tls: self.build_tls_config(), + limits: self.limits.clone(), } } } @@ -73,8 +96,10 @@ impl ClientBuilder { /// Builder for a mux WebSocket server. #[derive(Debug, Clone)] pub struct ServerBuilder { + pub(crate) addr: SocketAddr, pub(crate) opts: websock_proto::ServerOptions, pub(crate) tls: Option, + pub(crate) alpn: Option>>, pub(crate) limits: Limits, } @@ -83,12 +108,20 @@ impl ServerBuilder { let mut opts = websock_proto::ServerOptions::default(); opts.protocols.push(SUBPROTOCOL.to_string()); Self { + addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)), opts, tls: None, + alpn: None, limits: Limits::default(), } } + /// Set the bind address. + pub fn with_addr(mut self, addr: impl Into) -> Self { + self.addr = addr.into(); + self + } + pub fn with_options(mut self, opts: websock_proto::ServerOptions) -> Self { self.opts = opts; if !self.opts.protocols.iter().any(|p| p == SUBPROTOCOL) { @@ -103,6 +136,16 @@ impl ServerBuilder { self } + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + /// Configure session limits pub fn with_limits(mut self, limits: Limits) -> Self { self.limits = limits; @@ -119,10 +162,23 @@ impl ServerBuilder { self } - pub async fn build(self, addr: A) -> Result - where - A: tokio::net::ToSocketAddrs, - { - bind(addr, self.opts, self.tls, self.limits).await + fn build_tls_config(&self) -> Option { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(cfg) + } + + pub async fn build(&self) -> Result { + bind( + self.addr, + self.opts.clone(), + self.build_tls_config(), + self.limits.clone(), + ) + .await } } diff --git a/websock-tungstenite/src/builder.rs b/websock-tungstenite/src/builder.rs index b90e339..91663c5 100644 --- a/websock-tungstenite/src/builder.rs +++ b/websock-tungstenite/src/builder.rs @@ -1,19 +1,21 @@ //! Builders for clients and servers using the Tokio Tungstenite transport. +use crate::{Connection, Server}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::{ClientConfig, RootCertStore, ServerConfig}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::Arc; +use websock_proto::default_ws_alpn; use websock_proto::{ConnectOptions, Error, Result, ServerOptions}; -use crate::{Connection, Server}; - /// Builder for creating a WebSocket client. /// /// The resulting client can be reused for multiple `connect()` calls. #[derive(Debug, Clone)] pub struct ClientBuilder { opts: ConnectOptions, + tls: Option, + alpn: Option>>, } impl Default for ClientBuilder { @@ -27,6 +29,8 @@ impl ClientBuilder { pub fn new() -> Self { Self { opts: ConnectOptions::default(), + tls: None, + alpn: None, } } @@ -78,12 +82,10 @@ impl ClientBuilder { self } - /// Build a reusable client without custom TLS configuration. - pub fn build(self) -> Client { - Client { - opts: self.opts, - tls: None, - } + /// Configure a custom rustls client config (for `wss://`). + pub fn with_tls_config(mut self, tls: ClientConfig) -> Self { + self.tls = Some(tls); + self } /// Build a client configured with the system trust store. @@ -120,6 +122,34 @@ impl ClientBuilder { pub fn dangerous(self) -> DangerousClientBuilder { DangerousClientBuilder { opts: self.opts } } + + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + fn build_tls_config(&self) -> Option> { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(Arc::new(cfg)) + } + + /// Build a client. + pub fn build(&self) -> Client { + Client { + opts: self.opts.clone(), + tls: self.build_tls_config(), + } + } } /// Reusable WebSocket client created by [`ClientBuilder`]. @@ -163,6 +193,7 @@ pub struct ServerBuilder { addr: SocketAddr, opts: ServerOptions, tls: Option, + alpn: Option>>, } impl Default for ServerBuilder { @@ -177,6 +208,7 @@ impl ServerBuilder { Self { addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)), opts: ServerOptions::default(), + alpn: None, tls: None, } } @@ -230,8 +262,28 @@ impl ServerBuilder { self } + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + fn build_tls_config(&self) -> Option { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(cfg) + } + /// Bind the listener and return a server instance. - pub async fn build(self) -> Result { - crate::server::bind(&self.addr, self.opts, self.tls).await + pub async fn build(&self) -> Result { + crate::server::bind(self.addr, self.opts.clone(), self.build_tls_config()).await } } From 402f8253f0fd630e30184f0d22b04907b508067e Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 21:59:41 +0900 Subject: [PATCH 11/16] Impl WASM mux client --- websock-mux-proto/src/varint.rs | 6 +- websock-wasm-demo/Cargo.toml | 1 + websock-wasm-demo/index.html | 1 + websock-wasm-demo/src/echo.rs | 60 ++++ websock-wasm-demo/src/echo_mux.rs | 39 +++ websock-wasm-demo/src/lib.rs | 75 ++-- websock-wasm-mux/Cargo.toml | 13 + websock-wasm-mux/src/builder.rs | 86 +++++ websock-wasm-mux/src/client.rs | 24 ++ websock-wasm-mux/src/lib.rs | 15 +- websock-wasm-mux/src/session.rs | 560 ++++++++++++++++++++++++++++++ websock-wasm/src/connection.rs | 15 +- 12 files changed, 834 insertions(+), 61 deletions(-) create mode 100644 websock-wasm-demo/src/echo.rs create mode 100644 websock-wasm-demo/src/echo_mux.rs create mode 100644 websock-wasm-mux/src/builder.rs create mode 100644 websock-wasm-mux/src/client.rs create mode 100644 websock-wasm-mux/src/session.rs diff --git a/websock-mux-proto/src/varint.rs b/websock-mux-proto/src/varint.rs index 9d294c7..cc508b7 100644 --- a/websock-mux-proto/src/varint.rs +++ b/websock-mux-proto/src/varint.rs @@ -3,10 +3,14 @@ // Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src // Licensed under Apache-2.0 OR MIT -use std::{convert::TryInto, fmt, io::Cursor}; +use std::{convert::TryInto, fmt}; + +#[cfg(not(target_arch = "wasm32"))] +use std::io::Cursor; use bytes::{Buf, BufMut}; use thiserror::Error; + #[cfg(not(target_arch = "wasm32"))] use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; diff --git a/websock-wasm-demo/Cargo.toml b/websock-wasm-demo/Cargo.toml index b10e7dd..f71516e 100644 --- a/websock-wasm-demo/Cargo.toml +++ b/websock-wasm-demo/Cargo.toml @@ -10,6 +10,7 @@ crate-type = ["cdylib"] [dependencies] websock = { path = "../websock" } +websock-mux = { path = "../websock-mux" } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" console_error_panic_hook = "0.1" diff --git a/websock-wasm-demo/index.html b/websock-wasm-demo/index.html index 4737082..881e4d7 100644 --- a/websock-wasm-demo/index.html +++ b/websock-wasm-demo/index.html @@ -7,6 +7,7 @@ +

Open DevTools Console to see logs.

\ No newline at end of file diff --git a/websock-wasm-demo/src/echo.rs b/websock-wasm-demo/src/echo.rs new file mode 100644 index 0000000..931127c --- /dev/null +++ b/websock-wasm-demo/src/echo.rs @@ -0,0 +1,60 @@ +use futures_util::{SinkExt, StreamExt}; +use websock::{ClientBuilder, Message}; + +pub const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org"; + +pub async fn run_conn_demo(url: &str, log: impl Fn(&str)) { + log("[conn demo] start"); + + let client = ClientBuilder::new() + .with_system_roots() + .expect("client config failed"); + + log(&format!("[conn demo] Connecting to server at {url}...")); + log(&format!( + "[conn demo] Using WebSocket options: {:?}", + client.options() + )); + + let mut conn = client.connect(url).await.expect("connect failed"); + + conn.send(Message::Text("hello from wasm".into())) + .await + .expect("send failed"); + + let msg = conn.recv().await.expect("recv failed"); + log(&format!("[conn demo] got: {msg:?}")); + + conn.close().await.expect("close failed"); + log("[conn demo] done"); +} + +pub async fn run_split_demo(url: &str, log: impl Fn(&str)) { + log("[split demo] start"); + + let client = ClientBuilder::new() + .with_system_roots() + .expect("client config failed"); + + log(&format!("[split demo] Connecting to server at {url}...")); + log(&format!( + "[split demo] Using WebSocket options: {:?}", + client.options() + )); + + let conn = client.connect(url).await.expect("connect failed"); + let (mut sink, mut stream) = websock::stream::split(conn); + + sink.send(Message::Text("hello from wasm split".into())) + .await + .expect("send failed"); + + let msg = stream + .next() + .await + .expect("stream closed") + .expect("recv failed"); + + log(&format!("[split demo] got: {msg:?}")); + log("[split demo] done"); +} diff --git a/websock-wasm-demo/src/echo_mux.rs b/websock-wasm-demo/src/echo_mux.rs new file mode 100644 index 0000000..6fd5dc6 --- /dev/null +++ b/websock-wasm-demo/src/echo_mux.rs @@ -0,0 +1,39 @@ +pub const DEFAULT_MUX_URL: &str = "wss://localhost:9001"; + +pub async fn run_mux_bi_demo(url: &str, log: impl Fn(&str)) { + log("[mux bi demo] start"); + + let client = websock_mux::ClientBuilder::new() + .with_system_roots() + .expect("client config failed"); + + log(&format!( + "[mux bi demo] Connecting to mux server at {url}..." + )); + log(&format!( + "[mux bi demo] Using WebSocket options: {:?}", + client.options() + )); + + // Mux Session + let session = match client.connect(url).await { + Ok(s) => s, + Err(e) => { + log(&format!("[mux bi demo] connect failed: {e:?}")); + return; + } + }; + + let (send, mut recv) = session.open_bi().expect("open_bi failed"); + + send.write(b"hello mux from wasm").expect("send failed"); + send.finish().expect("finish failed"); + + let mut buf = vec![0u8; 1024]; + while let Some(n) = recv.read(&mut buf).await.expect("read failed") { + let text = String::from_utf8_lossy(&buf[..n]); + log(&format!("[mux bi demo] recv: {}", text)); + } + + log("[mux bi demo] done"); +} diff --git a/websock-wasm-demo/src/lib.rs b/websock-wasm-demo/src/lib.rs index 595ddb4..7123840 100644 --- a/websock-wasm-demo/src/lib.rs +++ b/websock-wasm-demo/src/lib.rs @@ -1,15 +1,14 @@ //! Minimal WebAssembly demo that exercises the websock API in the browser. +#![cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + +mod echo; +mod echo_mux; + use wasm_bindgen::prelude::*; use wasm_bindgen_futures::spawn_local; use web_sys::{Document, HtmlButtonElement}; -use futures_util::{SinkExt, StreamExt}; -use websock::{ClientBuilder, Message}; - -// Alternate echo server used for quick testing. -const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org"; - /// Obtain the active document. fn document() -> Document { web_sys::window().unwrap().document().unwrap() @@ -39,65 +38,39 @@ fn hook_button(id: &str, f: impl Fn() + 'static) { cb.forget(); } +fn prompt_url(default_url: &str) -> String { + let w = web_sys::window().unwrap(); + match w.prompt_with_message_and_default("URL?", default_url) { + Ok(Some(s)) if !s.trim().is_empty() => s, + _ => default_url.to_string(), + } +} + #[wasm_bindgen(start)] /// Entry point invoked by the browser when the module loads. pub fn start() { console_error_panic_hook::set_once(); hook_button("btn-conn", || { + let url = prompt_url(echo::DEFAULT_ECHO_URL); spawn_local(async move { - log("[conn demo] start"); - - let client = ClientBuilder::new() - .with_system_roots() - .expect("client config failed"); - - let mut conn = client - .connect(DEFAULT_ECHO_URL) - .await - .expect("connect failed"); - - conn.send(Message::Text("hello from wasm".into())) - .await - .expect("send failed"); - - let msg = conn.recv().await.expect("recv failed"); - log(&format!("[conn demo] got: {msg:?}")); - - conn.close().await.expect("close failed"); - log("[conn demo] done"); + echo::run_conn_demo(&url, log).await; }); }); hook_button("btn-split", || { + let url = prompt_url(echo::DEFAULT_ECHO_URL); spawn_local(async move { - log("[split demo] start"); - - let client = ClientBuilder::new() - .with_system_roots() - .expect("client config failed"); - - let conn = client - .connect(DEFAULT_ECHO_URL) - .await - .expect("connect failed"); - - let (mut sink, mut stream) = websock::stream::split(conn); - - sink.send(Message::Text("hello from wasm split".into())) - .await - .expect("send failed"); - - let msg = stream - .next() - .await - .expect("stream closed") - .expect("recv failed"); + echo::run_split_demo(&url, log).await; + }); + }); - log(&format!("[split demo] got: {msg:?}")); - log("[split demo] done"); + hook_button("btn-mux-bi", || { + let url = prompt_url(echo_mux::DEFAULT_MUX_URL); + spawn_local(async move { + echo_mux::run_mux_bi_demo(&url, log).await; }); }); - log("ready: click [conn demo] or [split demo]"); + log("ready: click [conn demo] / [split demo] / [mux bi demo]"); } diff --git a/websock-wasm-mux/Cargo.toml b/websock-wasm-mux/Cargo.toml index e413eab..b8e3ad6 100644 --- a/websock-wasm-mux/Cargo.toml +++ b/websock-wasm-mux/Cargo.toml @@ -3,5 +3,18 @@ name = "websock-wasm-mux" version.workspace = true edition.workspace = true authors.workspace = true +description = "WebAssembly WebSocket multiplexing transport (WebTransport-like streams over WebSocket)." +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex", "wasm"] +categories = ["network-programming", "web-programming"] +license = "MIT" [dependencies] +websock-proto = { workspace = true } +websock-mux-proto = { workspace = true } +websock-wasm = { workspace = true } +bytes = { workspace = true } +futures-channel = "0.3" +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } +wasm-bindgen-futures = "0.4" diff --git a/websock-wasm-mux/src/builder.rs b/websock-wasm-mux/src/builder.rs new file mode 100644 index 0000000..4cb330e --- /dev/null +++ b/websock-wasm-mux/src/builder.rs @@ -0,0 +1,86 @@ +//! Builder for browser mux clients. + +use crate::{Client, Limits}; +use websock_mux_proto::SUBPROTOCOL; +use websock_proto::{ConnectOptions, Result}; + +/// Builder for creating a browser WebSocket mux client. +/// +/// This wraps `websock-wasm` and enforces the mux `SUBPROTOCOL`. +#[derive(Debug, Clone)] +pub struct ClientBuilder { + opts: ConnectOptions, + limits: Limits, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + /// Create a new builder with default options. + pub fn new() -> Self { + let mut opts = ConnectOptions::default(); + opts.protocols.push(SUBPROTOCOL.to_string()); + Self { + opts, + limits: Limits::default(), + } + } + + /// Replace the builder options wholesale. + /// + /// Note: this will re-append the mux subprotocol if missing. + pub fn with_options(mut self, mut opts: ConnectOptions) -> Self { + if !opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + opts.protocols.push(SUBPROTOCOL.to_string()); + } + self.opts = opts; + self + } + + /// Return a reference to the current options. + pub fn options(&self) -> &ConnectOptions { + &self.opts + } + + /// Add a single header to the connection request. + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.opts.headers.push((name.into(), value.into())); + self + } + + /// Add multiple headers to the connection request. + pub fn with_headers(mut self, headers: I) -> Self + where + I: IntoIterator, + K: Into, + V: Into, + { + for (k, v) in headers { + self.opts.headers.push((k.into(), v.into())); + } + self + } + + /// Configure session limits. + pub fn limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + + /// Build a reusable client. + pub fn build(self) -> Client { + Client { + opts: self.opts, + limits: self.limits, + } + } + + /// Build a client (no-op for browser cert roots). + pub fn with_system_roots(self) -> Result { + Ok(self.build()) + } +} diff --git a/websock-wasm-mux/src/client.rs b/websock-wasm-mux/src/client.rs new file mode 100644 index 0000000..02015d8 --- /dev/null +++ b/websock-wasm-mux/src/client.rs @@ -0,0 +1,24 @@ +//! Browser mux client. + +use crate::{Limits, Session}; +use websock_proto::{ConnectOptions, Result}; + +/// Reusable browser WebSocket mux client created by [`crate::ClientBuilder`]. +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) opts: ConnectOptions, + pub(crate) limits: Limits, +} + +impl Client { + /// Return a reference to the configured connection options. + pub fn options(&self) -> &ConnectOptions { + &self.opts + } + + /// Establish a browser WebSocket connection and create a mux [`Session`]. + pub async fn connect(&self, url: &str) -> Result { + let conn = websock_wasm::connect(url, self.opts.clone()).await?; + Ok(Session::new(conn, self.limits.clone())) + } +} diff --git a/websock-wasm-mux/src/lib.rs b/websock-wasm-mux/src/lib.rs index d951046..e8e15f3 100644 --- a/websock-wasm-mux/src/lib.rs +++ b/websock-wasm-mux/src/lib.rs @@ -1,3 +1,12 @@ -pub fn todo() { - unimplemented!() -} +//! Browser (wasm32) WebSocket multiplexing transport. +//! +//! This crate provides a QUIC/WebTransport-like *logical stream* interface over a single +//! browser WebSocket connection. + +mod builder; +mod client; +mod session; + +pub use builder::ClientBuilder; +pub use client::Client; +pub use session::{Limits, RecvStream, SendStream, Session}; diff --git a/websock-wasm-mux/src/session.rs b/websock-wasm-mux/src/session.rs new file mode 100644 index 0000000..a14b469 --- /dev/null +++ b/websock-wasm-mux/src/session.rs @@ -0,0 +1,560 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +use bytes::{BufMut, Bytes}; +use futures_channel::mpsc; +use futures_util::lock::Mutex; +use futures_util::{FutureExt, StreamExt}; +use wasm_bindgen_futures::spawn_local; +use websock_proto::{Error, Message, Result}; + +use websock_mux_proto::{Frame, StreamDir, StreamId}; + +/// Session limits to prevent unbounded buffering / DoS. +#[derive(Debug, Clone)] +pub struct Limits { + /// Maximum size of a single WebSocket binary message accepted by the inbound loop. + pub max_ws_message_size: usize, + /// Maximum `Stream` frame payload size. + pub max_stream_data_per_frame: usize, + /// Maximum number of concurrently open receive streams. + pub max_open_streams: usize, + /// Per-stream receive event queue length. + pub recv_event_queue_len: usize, + /// Session outbound queue length. + pub outbound_queue_len: usize, + /// Queue length for accepting inbound uni streams. + pub accept_uni_queue_len: usize, + /// Queue length for accepting inbound bi streams. + pub accept_bi_queue_len: usize, +} + +impl Default for Limits { + fn default() -> Self { + Self { + max_ws_message_size: 1 * 1024 * 1024, + max_stream_data_per_frame: 256 * 1024, + max_open_streams: 1024, + recv_event_queue_len: 128, + outbound_queue_len: 256, + accept_uni_queue_len: 128, + accept_bi_queue_len: 128, + } + } +} + +#[derive(Clone)] +pub struct Session { + inner: Rc, + accept_uni: Rc>>, + accept_bi: Rc>>, +} + +impl Session { + pub(crate) fn new(conn: websock_wasm::Connection, limits: Limits) -> Self { + let (outbound_tx, outbound_rx) = mpsc::channel::(limits.outbound_queue_len); + let (accept_uni_tx, accept_uni_rx) = + mpsc::channel::(limits.accept_uni_queue_len); + let (accept_bi_tx, accept_bi_rx) = + mpsc::channel::<(SendStream, RecvStream)>(limits.accept_bi_queue_len); + + let inner = Rc::new(SessionInner::new( + limits, + outbound_tx, + accept_uni_tx, + accept_bi_tx, + )); + + let session = Self { + inner: inner.clone(), + accept_uni: Rc::new(Mutex::new(accept_uni_rx)), + accept_bi: Rc::new(Mutex::new(accept_bi_rx)), + }; + + inner.spawn_task(conn, outbound_rx); + session + } + + pub fn open_uni(&self) -> Result { + let id = self.inner.next_stream_id(StreamDir::Uni)?; + self.inner.send_frame(Frame::OpenUni { id })?; + Ok(SendStream::new(id, self.inner.clone())) + } + + pub fn open_bi(&self) -> Result<(SendStream, RecvStream)> { + let id = self.inner.next_stream_id(StreamDir::Bi)?; + let recv = self.inner.clone().register_recv_stream(id); + self.inner.send_frame(Frame::OpenBi { id })?; + Ok((SendStream::new(id, self.inner.clone()), recv)) + } + + pub async fn accept_uni(&self) -> Result { + let mut rx = self.accept_uni.lock().await; + rx.next().await.ok_or(Error::Closed) + } + + pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream)> { + let mut rx = self.accept_bi.lock().await; + rx.next().await.ok_or(Error::Closed) + } +} + +#[derive(Clone)] +pub struct SendStream { + id: StreamId, + session: Rc, + finished: Rc, +} + +impl SendStream { + fn new(id: StreamId, session: Rc) -> Self { + Self { + id, + session, + finished: Rc::new(AtomicBool::new(false)), + } + } + + pub fn write(&self, data: &[u8]) -> Result<()> { + self.write_buf(Bytes::copy_from_slice(data)) + } + + pub fn write_buf(&self, data: Bytes) -> Result<()> { + self.session.send_frame(Frame::Stream { + id: self.id, + data, + fin: false, + }) + } + + pub fn write_all(&self, data: &[u8]) -> Result<()> { + self.write(data) + } + + pub fn finish(&self) -> Result<()> { + if self + .finished + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + self.session.send_frame(Frame::Stream { + id: self.id, + data: Bytes::new(), + fin: true, + })?; + } + Ok(()) + } + + pub fn reset(&self, code: u64) -> Result<()> { + self.finished.store(true, Ordering::SeqCst); + self.session + .send_frame(Frame::ResetStream { id: self.id, code }) + } + + pub fn closed(&self) -> bool { + self.finished.load(Ordering::SeqCst) + } +} + +impl Drop for SendStream { + fn drop(&mut self) { + if !self.finished.load(Ordering::SeqCst) { + let _ = self.session.try_send_frame(Frame::ResetStream { + id: self.id, + code: 0, + }); + } + } +} + +#[derive(Debug)] +struct RecvEvent { + data: Bytes, + fin: bool, +} + +pub struct RecvStream { + id: StreamId, + session: Rc, + receiver: mpsc::Receiver, + finished: bool, + pending: Bytes, +} + +impl RecvStream { + fn new(id: StreamId, session: Rc, receiver: mpsc::Receiver) -> Self { + Self { + id, + session, + receiver, + finished: false, + pending: Bytes::new(), + } + } + + pub async fn read(&mut self, buf: &mut [u8]) -> Result> { + if self.finished { + return Ok(None); + } + if self.pending.is_empty() { + if let Some(chunk) = self.read_chunk_internal().await? { + self.pending = chunk; + } else { + return Ok(None); + } + } + let amt = buf.len().min(self.pending.len()); + buf[..amt].copy_from_slice(&self.pending[..amt]); + self.pending = self.pending.slice(amt..); + Ok(Some(amt)) + } + + pub async fn read_buf(&mut self, buf: &mut B) -> Result> { + let mut temp = vec![0u8; 4096]; + let size_opt = self.read(&mut temp).await?; + if let Some(size) = size_opt { + buf.put_slice(&temp[..size]); + Ok(Some(size)) + } else { + Ok(None) + } + } + + async fn read_chunk_internal(&mut self) -> Result> { + match self.receiver.next().await { + Some(event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub async fn read_chunk(&mut self, max: usize) -> Result> { + match self.receiver.next().await { + Some(mut event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else if event.data.len() > max { + let chunk = event.data.split_to(max); + self.pending = event.data; + Ok(Some(chunk)) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub fn stop(&self, code: u64) -> Result<()> { + self.session + .send_frame(Frame::StopSending { id: self.id, code }) + } + + pub fn closed(&self) -> bool { + self.finished + } +} + +impl Drop for RecvStream { + fn drop(&mut self) { + if !self.finished { + let _ = self.session.try_send_frame(Frame::StopSending { + id: self.id, + code: 0, + }); + } + } +} + +struct SessionInner { + limits: Limits, + outbound_tx: RefCell>, + accept_uni_tx: Mutex>>, + accept_bi_tx: Mutex>>, + streams: RefCell>>, + next_uni: AtomicU64, + next_bi: AtomicU64, + closed: AtomicBool, +} + +impl SessionInner { + fn new( + limits: Limits, + outbound_tx: mpsc::Sender, + accept_uni_tx: mpsc::Sender, + accept_bi_tx: mpsc::Sender<(SendStream, RecvStream)>, + ) -> Self { + Self { + limits, + outbound_tx: RefCell::new(outbound_tx), + accept_uni_tx: Mutex::new(Some(accept_uni_tx)), + accept_bi_tx: Mutex::new(Some(accept_bi_tx)), + streams: RefCell::new(HashMap::new()), + next_uni: AtomicU64::new(0), + next_bi: AtomicU64::new(0), + closed: AtomicBool::new(false), + } + } + + fn spawn_task( + self: Rc, + mut conn: websock_wasm::Connection, + mut outbound_rx: mpsc::Receiver, + ) { + let inner = self.clone(); + spawn_local(async move { + loop { + futures_util::select! { + msg = conn.recv().fuse() => { + match msg { + Ok(Message::Binary(data)) => { + if data.len() > inner.limits.max_ws_message_size { + let _ = inner.protocol_error(2, "ws message too large").await; + break; + } + let mut cursor = &data[..]; + let frame = match Frame::decode(&mut cursor) { + Ok(f) => f, + Err(_) => { + let _ = inner.protocol_error(1, "invalid frame").await; + break; + } + }; + if inner.clone().handle_frame(frame).await.is_err() { + break; + } + } + Ok(Message::Text(_)) => { + let _ = inner.protocol_error(1, "text message not supported").await; + break; + } + Err(_) => break, + } + } + out = outbound_rx.next().fuse() => { + match out { + Some(m) => { + if conn.send(m).await.is_err() { + break; + } + } + None => break, + } + } + } + } + + inner.close_all().await; + let _ = conn.close().await; + }); + } + + fn next_stream_id(&self, dir: StreamDir) -> Result { + let is_server = false; // browser is always client + let n = match dir { + StreamDir::Uni => self.next_uni.fetch_add(1, Ordering::SeqCst), + StreamDir::Bi => self.next_bi.fetch_add(1, Ordering::SeqCst), + }; + StreamId::new(n, is_server, dir) + .map_err(|e| Error::Protocol(format!("stream id overflow: {}", e))) + } + + async fn handle_frame(self: Rc, frame: Frame) -> Result<()> { + match frame { + Frame::OpenUni { id } => { + if id.dir() != StreamDir::Uni { + return self + .protocol_error(1, "OpenUni with non-uni StreamId") + .await; + } + // In browser wasm, we are always the client, so the peer is the server. + // Therefore, inbound streams must be server-initiated. + if !id.initiator_is_server() { + return self.protocol_error(1, "OpenUni with wrong initiator").await; + } + + let mut map = self.streams.borrow_mut(); + if map.len() >= self.limits.max_open_streams { + drop(map); + return self.protocol_error(3, "too many open streams").await; + } + if map.contains_key(&id) { + drop(map); + return self.protocol_error(1, "duplicate stream open").await; + } + + let recv = Self::register_recv_stream_locked(&self, &mut map, id); + drop(map); + + let tx = self.accept_uni_tx.lock().await.clone(); + if let Some(mut tx) = tx { + match tx.try_send(recv) { + Ok(()) => Ok(()), + Err(e) => { + if e.is_full() { + self.streams.borrow_mut().remove(&id); + let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 }); + Ok(()) + } else { + Err(Error::Closed) + } + } + } + } else { + Err(Error::Closed) + } + } + Frame::OpenBi { id } => { + if id.dir() != StreamDir::Bi { + return self.protocol_error(1, "OpenBi with non-bi StreamId").await; + } + if !id.initiator_is_server() { + return self.protocol_error(1, "OpenBi with wrong initiator").await; + } + + let mut map = self.streams.borrow_mut(); + if map.len() >= self.limits.max_open_streams { + drop(map); + return self.protocol_error(3, "too many open streams").await; + } + if map.contains_key(&id) { + drop(map); + return self.protocol_error(1, "duplicate stream open").await; + } + + let recv = Self::register_recv_stream_locked(&self, &mut map, id); + drop(map); + + let send = SendStream::new(id, self.clone()); + let tx = self.accept_bi_tx.lock().await.clone(); + if let Some(mut tx) = tx { + match tx.try_send((send, recv)) { + Ok(()) => Ok(()), + Err(e) => { + if e.is_full() { + self.streams.borrow_mut().remove(&id); + let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 }); + Ok(()) + } else { + Err(Error::Closed) + } + } + } + } else { + Err(Error::Closed) + } + } + Frame::Stream { id, data, fin } => { + if data.len() > self.limits.max_stream_data_per_frame { + return self.protocol_error(2, "stream data too large").await; + } + + let mut map = self.streams.borrow_mut(); + let Some(tx) = map.get_mut(&id) else { + drop(map); + return self + .protocol_error(1, "Stream data on unknown stream") + .await; + }; + + match tx.try_send(RecvEvent { data, fin }) { + Ok(()) => {} + Err(e) => { + if e.is_full() { + map.remove(&id); + drop(map); + let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } else { + map.remove(&id); + return Ok(()); + } + } + } + + if fin { + map.remove(&id); + } + Ok(()) + } + Frame::ResetStream { id, .. } | Frame::StopSending { id, .. } => { + self.streams.borrow_mut().remove(&id); + Ok(()) + } + Frame::ConnectionClose { .. } => { + self.close_all().await; + Err(Error::Closed) + } + } + } + + fn register_recv_stream(self: Rc, id: StreamId) -> RecvStream { + let mut map = self.streams.borrow_mut(); + Self::register_recv_stream_locked(&self, &mut map, id) + } + + fn register_recv_stream_locked( + this: &Rc, + map: &mut HashMap>, + id: StreamId, + ) -> RecvStream { + let (tx, rx) = mpsc::channel(this.limits.recv_event_queue_len); + map.insert(id, tx); + RecvStream::new(id, this.clone(), rx) + } + + fn try_send_frame(&self, frame: Frame) -> std::result::Result<(), Error> { + let bytes = frame.encode().freeze(); + self.outbound_tx + .borrow_mut() + .try_send(Message::Binary(bytes)) + .map_err(|_| Error::Closed) + } + + fn send_frame(&self, frame: Frame) -> Result<()> { + self.try_send_frame(frame) + } + + async fn protocol_error(&self, code: u64, reason: &str) -> Result<()> { + let _ = self.try_send_frame(Frame::ConnectionClose { + code, + reason: reason.to_string(), + }); + self.close_all().await; + Err(Error::Protocol(reason.to_string())) + } + + async fn close_all(&self) { + if self + .closed + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return; + } + + self.streams.borrow_mut().clear(); + *self.accept_uni_tx.lock().await = None; + *self.accept_bi_tx.lock().await = None; + } +} diff --git a/websock-wasm/src/connection.rs b/websock-wasm/src/connection.rs index 87f7e60..7af5f6d 100644 --- a/websock-wasm/src/connection.rs +++ b/websock-wasm/src/connection.rs @@ -58,21 +58,24 @@ pub async fn connect(url: &str, opts: ConnectOptions) -> Result { ws.set_onclose(Some(wait_onclose.as_ref().unchecked_ref())); // Wait until the connection is opened or fails. - match open_rx.await { - Ok(Ok(())) => {} - Ok(Err(e)) => return Err(e), - Err(_) => return Err(Error::Other("onopen waiter dropped".into())), - } + let open_res = open_rx.await; - // Unset the connection process handlers. + // Always unset the connection process handlers. ws.set_onopen(None); ws.set_onerror(None); ws.set_onclose(None); + // Drop closures AFTER unsetting. drop(wait_onopen); drop(wait_onerror); drop(wait_onclose); + match open_res { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(_) => return Err(Error::Other("onopen waiter dropped".into())), + } + // Set up message/error/close handlers. let tx_msg = tx.clone(); let onmessage = From 300e1aacd43ad4b5be72dbf61859a23edd92d6ed Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 22:27:38 +0900 Subject: [PATCH 12/16] Update examples --- websock-mux/examples/echo-client-mux.rs | 20 ++++++-------------- websock-mux/examples/echo-server-mux.rs | 5 ++--- websock/examples/echo-server.rs | 2 +- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs index b8a38cd..0886648 100644 --- a/websock-mux/examples/echo-client-mux.rs +++ b/websock-mux/examples/echo-client-mux.rs @@ -79,9 +79,7 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result { let tls_cfg: ClientConfig = if args.tls_disable_verify { tracing::warn!("disabling TLS certificate verification"); - TlsClientConfigBuilder::new_insecure()? - .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) - .build() + TlsClientConfigBuilder::new_insecure()?.build() } else if let Some(path) = &args.tls_cert { let certs = tls::cert::load_certs(path)?; anyhow::ensure!(!certs.is_empty(), "could not find certificate"); @@ -91,27 +89,21 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result { let _ = roots.add(c); } - let mut cfg = ClientConfig::builder() + ClientConfig::builder() .with_root_certificates(roots) - .with_no_client_auth(); - cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()]; - cfg + .with_no_client_auth() } else if is_local { tracing::warn!( "no --tls-cert provided and target looks local ({}); \ using insecure mode for quick testing (equivalent to --tls-disable-verify)", url ); - TlsClientConfigBuilder::new_insecure()? - .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) - .build() + TlsClientConfigBuilder::new_insecure()?.build() } else { - TlsClientConfigBuilder::new_with_native_certs()? - .with_alpn_protocols(vec![b"http/1.1".to_vec(), b"h2".to_vec()]) - .build() + TlsClientConfigBuilder::new_with_native_certs()?.build() }; - Ok(ClientBuilder::new().with_tls_config(tls_cfg).build()) + Ok(ClientBuilder::new().with_tls_config(tls_cfg).with_default_alpn().build()) } /// Determine whether a URL points to a loopback host. diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs index d3b0c3c..fcfaa60 100644 --- a/websock-mux/examples/echo-server-mux.rs +++ b/websock-mux/examples/echo-server-mux.rs @@ -36,16 +36,15 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); - let mut builder = ServerBuilder::new().with_addr(args.addr); + let mut builder = ServerBuilder::new().with_addr(args.addr).with_default_alpn(); match (&args.tls_cert, &args.tls_key) { (Some(cert_path), Some(key_path)) => { let (chain, key) = load_pem_cert_and_key(cert_path, key_path)?; - let mut cfg = rustls::ServerConfig::builder() + let cfg = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(chain, key) .map_err(|e| anyhow::anyhow!(e.to_string()))?; - cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()]; builder = builder.with_tls_config(cfg); tracing::info!("TLS enabled (wss://)"); } diff --git a/websock/examples/echo-server.rs b/websock/examples/echo-server.rs index 1a039e6..6a690ef 100644 --- a/websock/examples/echo-server.rs +++ b/websock/examples/echo-server.rs @@ -36,7 +36,7 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); - let mut builder = ServerBuilder::new().with_addr(args.addr); + let mut builder = ServerBuilder::new().with_addr(args.addr).with_default_alpn(); match (&args.tls_cert, &args.tls_key) { (Some(cert_path), Some(key_path)) => { From 2692da9fa1c1291fb5534ccf4941201c4610422c Mon Sep 17 00:00:00 2001 From: shellrow Date: Sat, 17 Jan 2026 22:31:33 +0900 Subject: [PATCH 13/16] Update WASM demo --- websock-wasm-demo/Cargo.toml | 2 +- websock-wasm-demo/index.html | 39 +++++++++--- websock-wasm-demo/src/echo.rs | 2 +- websock-wasm-demo/src/lib.rs | 110 +++++++++++++++++++++------------- 4 files changed, 101 insertions(+), 52 deletions(-) diff --git a/websock-wasm-demo/Cargo.toml b/websock-wasm-demo/Cargo.toml index f71516e..b8cfc26 100644 --- a/websock-wasm-demo/Cargo.toml +++ b/websock-wasm-demo/Cargo.toml @@ -14,5 +14,5 @@ websock-mux = { path = "../websock-mux" } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" console_error_panic_hook = "0.1" -web-sys = { version = "0.3", features = ["console", "Document", "HtmlButtonElement", "Window"] } +web-sys = { version = "0.3", features = ["console", "Document", "HtmlInputElement", "HtmlButtonElement", "Window"] } futures-util = { version = "0.3", features = ["sink"] } diff --git a/websock-wasm-demo/index.html b/websock-wasm-demo/index.html index 881e4d7..47ebf50 100644 --- a/websock-wasm-demo/index.html +++ b/websock-wasm-demo/index.html @@ -1,13 +1,38 @@ - + - websock wasm demo + + websock-wasm-demo + - - - -

Open DevTools Console to see logs.

+

websock-wasm-demo

+ +
+ +
+ +
+ + + +
+ + + Tip: For TLS testing use wss:// and a certificate that your browser trusts. + + +

log

+

+
+    
   
-
\ No newline at end of file
+
diff --git a/websock-wasm-demo/src/echo.rs b/websock-wasm-demo/src/echo.rs
index 931127c..6739b44 100644
--- a/websock-wasm-demo/src/echo.rs
+++ b/websock-wasm-demo/src/echo.rs
@@ -1,7 +1,7 @@
 use futures_util::{SinkExt, StreamExt};
 use websock::{ClientBuilder, Message};
 
-pub const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org";
+//pub const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org";
 
 pub async fn run_conn_demo(url: &str, log: impl Fn(&str)) {
     log("[conn demo] start");
diff --git a/websock-wasm-demo/src/lib.rs b/websock-wasm-demo/src/lib.rs
index 7123840..ed59df0 100644
--- a/websock-wasm-demo/src/lib.rs
+++ b/websock-wasm-demo/src/lib.rs
@@ -1,76 +1,100 @@
-//! Minimal WebAssembly demo that exercises the websock API in the browser.
+//! Minimal WebAssembly demo for WebSocket + WebSocket-mux in the browser.
 
 #![cfg(all(target_arch = "wasm32", not(target_os = "wasi")))]
 
 mod echo;
 mod echo_mux;
 
+use wasm_bindgen::JsCast;
 use wasm_bindgen::prelude::*;
 use wasm_bindgen_futures::spawn_local;
-use web_sys::{Document, HtmlButtonElement};
+use web_sys::{Document, HtmlButtonElement, HtmlInputElement};
 
-/// Obtain the active document.
 fn document() -> Document {
     web_sys::window().unwrap().document().unwrap()
 }
 
-/// Log a message to the browser console.
-fn log(msg: &str) {
-    web_sys::console::log_1(&msg.into());
+fn by_id(id: &str) -> T {
+    document()
+        .get_element_by_id(id)
+        .unwrap_or_else(|| panic!("missing element: #{id}"))
+        .dyn_into()
+        .unwrap()
 }
 
-/// Bind a click handler to a button by element ID.
-fn hook_button(id: &str, f: impl Fn() + 'static) {
+fn append_log(msg: &str) {
     let doc = document();
-    let el = doc
-        .get_element_by_id(id)
-        .unwrap_or_else(|| panic!("missing element: #{id}"));
-    let btn: HtmlButtonElement = el.dyn_into().unwrap();
+    let el = doc.get_element_by_id("log").unwrap();
+    let mut text = el.text_content().unwrap_or_default();
+    text.push_str(msg);
+    text.push('\n');
+    el.set_text_content(Some(&text));
 
-    let cb = Closure::::new(move || {
-        f();
-    });
-
-    btn.add_event_listener_with_callback("click", cb.as_ref().unchecked_ref())
-        .unwrap();
-
-    // Keep the handler alive for the lifetime of the page.
-    cb.forget();
+    // Also mirror into DevTools for convenience.
+    web_sys::console::log_1(&msg.into());
 }
 
-fn prompt_url(default_url: &str) -> String {
-    let w = web_sys::window().unwrap();
-    match w.prompt_with_message_and_default("URL?", default_url) {
-        Ok(Some(s)) if !s.trim().is_empty() => s,
-        _ => default_url.to_string(),
+fn get_url() -> String {
+    let input: HtmlInputElement = by_id("url");
+    let url = input.value();
+    if url.trim().is_empty() {
+        echo_mux::DEFAULT_MUX_URL.to_string()
+    } else {
+        url
     }
 }
 
 #[wasm_bindgen(start)]
-/// Entry point invoked by the browser when the module loads.
 pub fn start() {
     console_error_panic_hook::set_once();
 
-    hook_button("btn-conn", || {
-        let url = prompt_url(echo::DEFAULT_ECHO_URL);
-        spawn_local(async move {
-            echo::run_conn_demo(&url, log).await;
+    // Default URL.
+    {
+        let input: HtmlInputElement = by_id("url");
+        if input.value().trim().is_empty() {
+            input.set_value(echo_mux::DEFAULT_MUX_URL);
+        }
+    }
+
+    let btn_conn: HtmlButtonElement = by_id("btn-conn");
+    let btn_split: HtmlButtonElement = by_id("btn-split");
+    let btn_mux_bi: HtmlButtonElement = by_id("btn-mux-bi");
+
+    // conn demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo::run_conn_demo(&url, |m| append_log(m)).await;
+            });
         });
-    });
+        btn_conn.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
 
-    hook_button("btn-split", || {
-        let url = prompt_url(echo::DEFAULT_ECHO_URL);
-        spawn_local(async move {
-            echo::run_split_demo(&url, log).await;
+    // split demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo::run_split_demo(&url, |m| append_log(m)).await;
+            });
         });
-    });
+        btn_split.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
 
-    hook_button("btn-mux-bi", || {
-        let url = prompt_url(echo_mux::DEFAULT_MUX_URL);
-        spawn_local(async move {
-            echo_mux::run_mux_bi_demo(&url, log).await;
+    // mux bi demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo_mux::run_mux_bi_demo(&url, |m| append_log(m)).await;
+            });
         });
-    });
+        btn_mux_bi.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
 
-    log("ready: click [conn demo] / [split demo] / [mux bi demo]");
+    append_log("ready");
 }

From 5a873510102b442621f2c4c3024e36d017e421cd Mon Sep 17 00:00:00 2001
From: shellrow 
Date: Sat, 17 Jan 2026 22:33:00 +0900
Subject: [PATCH 14/16] Format code with cargo fmt

---
 websock-mux/examples/echo-client-mux.rs | 5 ++++-
 websock-mux/examples/echo-server-mux.rs | 4 +++-
 websock/examples/echo-server.rs         | 4 +++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs
index 0886648..37d0778 100644
--- a/websock-mux/examples/echo-client-mux.rs
+++ b/websock-mux/examples/echo-client-mux.rs
@@ -103,7 +103,10 @@ fn build_client(url: &Url, args: &Args) -> anyhow::Result {
         TlsClientConfigBuilder::new_with_native_certs()?.build()
     };
 
-    Ok(ClientBuilder::new().with_tls_config(tls_cfg).with_default_alpn().build())
+    Ok(ClientBuilder::new()
+        .with_tls_config(tls_cfg)
+        .with_default_alpn()
+        .build())
 }
 
 /// Determine whether a URL points to a loopback host.
diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs
index fcfaa60..60eebc1 100644
--- a/websock-mux/examples/echo-server-mux.rs
+++ b/websock-mux/examples/echo-server-mux.rs
@@ -36,7 +36,9 @@ async fn main() -> anyhow::Result<()> {
 
     let args = Args::parse();
 
-    let mut builder = ServerBuilder::new().with_addr(args.addr).with_default_alpn();
+    let mut builder = ServerBuilder::new()
+        .with_addr(args.addr)
+        .with_default_alpn();
 
     match (&args.tls_cert, &args.tls_key) {
         (Some(cert_path), Some(key_path)) => {
diff --git a/websock/examples/echo-server.rs b/websock/examples/echo-server.rs
index 6a690ef..a009bff 100644
--- a/websock/examples/echo-server.rs
+++ b/websock/examples/echo-server.rs
@@ -36,7 +36,9 @@ async fn main() -> anyhow::Result<()> {
 
     let args = Args::parse();
 
-    let mut builder = ServerBuilder::new().with_addr(args.addr).with_default_alpn();
+    let mut builder = ServerBuilder::new()
+        .with_addr(args.addr)
+        .with_default_alpn();
 
     match (&args.tls_cert, &args.tls_key) {
         (Some(cert_path), Some(key_path)) => {

From 240ff615a3b8ff3473311b2252e743aee50d01f9 Mon Sep 17 00:00:00 2001
From: shellrow 
Date: Sun, 18 Jan 2026 10:51:41 +0900
Subject: [PATCH 15/16] Create .gitattributes

---
 .gitattributes | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 .gitattributes

diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..1a7331f
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+/devcert/** -linguist-detectable

From 0821923da7b7de35c9c26fdf6b57dde3b36be558 Mon Sep 17 00:00:00 2001
From: shellrow 
Date: Sun, 18 Jan 2026 10:56:22 +0900
Subject: [PATCH 16/16] Bump version to 0.2.0

---
 Cargo.toml | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 734ab95..489eb63 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,15 +13,15 @@ members = [
 ]
 
 [workspace.package]
-version = "0.1.0"
+version = "0.2.0"
 edition = "2024"
 authors = ["shellrow "]
 
 [workspace.dependencies]
-websock-proto = { path = "websock-proto", version = "0.1.0" }
-websock-tungstenite = { path = "websock-tungstenite", version = "0.1.0" }
-websock-wasm = { path = "websock-wasm", version = "0.1.0" }
-websock-mux-proto = { path = "websock-mux-proto", version = "0.1.0" }
-websock-tungstenite-mux = { path = "websock-tungstenite-mux", version = "0.1.0" }
-websock-wasm-mux = { path = "websock-wasm-mux", version = "0.1.0" }
+websock-proto = { path = "websock-proto", version = "0.2.0" }
+websock-tungstenite = { path = "websock-tungstenite", version = "0.2.0" }
+websock-wasm = { path = "websock-wasm", version = "0.2.0" }
+websock-mux-proto = { path = "websock-mux-proto", version = "0.2.0" }
+websock-tungstenite-mux = { path = "websock-tungstenite-mux", version = "0.2.0" }
+websock-wasm-mux = { path = "websock-wasm-mux", version = "0.2.0" }
 bytes = "1"