diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..1a7331f --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +/devcert/** -linguist-detectable diff --git a/Cargo.toml b/Cargo.toml index d5d9ff6..489eb63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,16 +5,23 @@ members = [ "websock-proto", "websock-tungstenite", "websock-wasm", - "websock-wasm-demo" + "websock-wasm-demo", + "websock-mux", + "websock-tungstenite-mux", + "websock-wasm-mux", + "websock-mux-proto" ] [workspace.package] -version = "0.1.0" +version = "0.2.0" edition = "2024" authors = ["shellrow "] [workspace.dependencies] -websock-proto = { path = "websock-proto", version = "0.1.0" } -websock-tungstenite = { path = "websock-tungstenite", version = "0.1.0" } -websock-wasm = { path = "websock-wasm", version = "0.1.0" } +websock-proto = { path = "websock-proto", version = "0.2.0" } +websock-tungstenite = { path = "websock-tungstenite", version = "0.2.0" } +websock-wasm = { path = "websock-wasm", version = "0.2.0" } +websock-mux-proto = { path = "websock-mux-proto", version = "0.2.0" } +websock-tungstenite-mux = { path = "websock-tungstenite-mux", version = "0.2.0" } +websock-wasm-mux = { path = "websock-wasm-mux", version = "0.2.0" } bytes = "1" diff --git a/devcert/.gitignore b/devcert/.gitignore new file mode 100644 index 0000000..89a9fbc --- /dev/null +++ b/devcert/.gitignore @@ -0,0 +1,4 @@ +*.crt +*.hex +*.key +*.fingerprint diff --git a/devcert/generate.sh b/devcert/generate.sh new file mode 100755 index 0000000..bb60067 --- /dev/null +++ b/devcert/generate.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "${BASH_SOURCE[0]}")" + +CERT_NAME="localhost" +DAYS=10 +CONF="openssl.conf" + +echo "==> Generating self-signed certificate (${CERT_NAME})" +echo " validity: ${DAYS} days" +echo " config: ${CONF}" + +# Generate ECDSA P-256 private key +openssl ecparam \ + -genkey \ + -name prime256v1 \ + -out "${CERT_NAME}.key" + +# Generate self-signed certificate +openssl req \ + -x509 \ + -sha256 \ + -nodes \ + -days "${DAYS}" \ + -key "${CERT_NAME}.key" \ + -out "${CERT_NAME}.crt" \ + -config "${CONF}" \ + -extensions v3_req + +# Generate raw SHA-256 hash (DER -> hex, no colons) +openssl x509 \ + -in "${CERT_NAME}.crt" \ + -outform der \ +| openssl dgst -sha256 -binary \ +| xxd -p -c 256 \ +> "${CERT_NAME}.hex" + +# Also print human-readable fingerprint +openssl x509 \ + -in "${CERT_NAME}.crt" \ + -noout \ + -fingerprint -sha256 \ +> "${CERT_NAME}.fingerprint" + +echo "==> Done" +echo " - ${CERT_NAME}.crt" +echo " - ${CERT_NAME}.key" +echo " - ${CERT_NAME}.hex" +echo " - ${CERT_NAME}.fingerprint" diff --git a/devcert/openssl.conf b/devcert/openssl.conf new file mode 100644 index 0000000..fa64821 --- /dev/null +++ b/devcert/openssl.conf @@ -0,0 +1,22 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext +x509_extensions = v3_req +prompt = no + +[req_distinguished_name] +CN = localhost + +[req_ext] +subjectAltName = @alt_names + +[v3_req] +basicConstraints = CA:FALSE +keyUsage = digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +IP.2 = ::1 diff --git a/websock-mux-proto/Cargo.toml b/websock-mux-proto/Cargo.toml new file mode 100644 index 0000000..cb83977 --- /dev/null +++ b/websock-mux-proto/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "websock-mux-proto" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "Protocol for multiplexing WebSocket logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +bytes = { workspace = true } +thiserror = "2" +websock-proto = { workspace = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "1", features = ["io-util"] } diff --git a/websock-mux-proto/src/lib.rs b/websock-mux-proto/src/lib.rs new file mode 100644 index 0000000..c4dc2d7 --- /dev/null +++ b/websock-mux-proto/src/lib.rs @@ -0,0 +1,7 @@ +pub mod stream; +pub mod varint; + +pub use stream::{Frame, FrameDecodeError, StreamDir, StreamId}; +pub use varint::{VarInt, VarIntBoundsExceeded, VarIntUnexpectedEnd}; + +pub const SUBPROTOCOL: &str = "websock-mux-1"; diff --git a/websock-mux-proto/src/stream.rs b/websock-mux-proto/src/stream.rs new file mode 100644 index 0000000..74886c0 --- /dev/null +++ b/websock-mux-proto/src/stream.rs @@ -0,0 +1,173 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::varint::{VarInt, VarIntBoundsExceeded, VarIntUnexpectedEnd}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum StreamDir { + Bi, + Uni, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct StreamId(pub u64); + +impl StreamId { + pub fn new( + counter: u64, + is_server: bool, + dir: StreamDir, + ) -> Result { + let initiator = if is_server { 1 } else { 0 }; + let dir_bit = match dir { + StreamDir::Bi => 0, + StreamDir::Uni => 1, + }; + let value = counter.checked_shl(2).ok_or(VarIntBoundsExceeded)? + | ((dir_bit as u64) << 1) + | (initiator as u64); + VarInt::from_u64(value)?; + Ok(Self(value)) + } + + pub fn dir(self) -> StreamDir { + if (self.0 >> 1) & 1 == 1 { + StreamDir::Uni + } else { + StreamDir::Bi + } + } + + pub fn initiator_is_server(self) -> bool { + self.0 & 1 == 1 + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Frame { + OpenUni { + id: StreamId, + }, + OpenBi { + id: StreamId, + }, + Stream { + id: StreamId, + data: Bytes, + fin: bool, + }, + ResetStream { + id: StreamId, + code: u64, + }, + StopSending { + id: StreamId, + code: u64, + }, + ConnectionClose { + code: u64, + reason: String, + }, +} + +impl Frame { + pub fn encode(&self) -> BytesMut { + let mut buf = BytesMut::new(); + match self { + Frame::OpenUni { id } => { + VarInt(0).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + } + Frame::OpenBi { id } => { + VarInt(1).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + } + Frame::Stream { id, data, fin } => { + VarInt(2).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(u64::from(*fin)).encode(&mut buf); + VarInt(data.len() as u64).encode(&mut buf); + buf.put_slice(data); + } + Frame::ResetStream { id, code } => { + VarInt(3).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(*code).encode(&mut buf); + } + Frame::StopSending { id, code } => { + VarInt(4).encode(&mut buf); + VarInt(id.0).encode(&mut buf); + VarInt(*code).encode(&mut buf); + } + Frame::ConnectionClose { code, reason } => { + VarInt(5).encode(&mut buf); + VarInt(*code).encode(&mut buf); + VarInt(reason.len() as u64).encode(&mut buf); + buf.put_slice(reason.as_bytes()); + } + } + buf + } + + pub fn decode(buf: &mut B) -> Result { + let tag = VarInt::decode(buf)?.into_inner(); + match tag { + 0 => Ok(Frame::OpenUni { + id: StreamId(VarInt::decode(buf)?.into_inner()), + }), + 1 => Ok(Frame::OpenBi { + id: StreamId(VarInt::decode(buf)?.into_inner()), + }), + 2 => { + let id = StreamId(VarInt::decode(buf)?.into_inner()); + let fin = VarInt::decode(buf)?.into_inner() != 0; + let len = VarInt::decode(buf)?.into_inner() as usize; + if buf.remaining() < len { + return Err(FrameDecodeError::UnexpectedEnd); + } + let mut data = vec![0u8; len]; + buf.copy_to_slice(&mut data); + Ok(Frame::Stream { + id, + data: Bytes::from(data), + fin, + }) + } + 3 => Ok(Frame::ResetStream { + id: StreamId(VarInt::decode(buf)?.into_inner()), + code: VarInt::decode(buf)?.into_inner(), + }), + 4 => Ok(Frame::StopSending { + id: StreamId(VarInt::decode(buf)?.into_inner()), + code: VarInt::decode(buf)?.into_inner(), + }), + 5 => { + let code = VarInt::decode(buf)?.into_inner(); + let len = VarInt::decode(buf)?.into_inner() as usize; + if buf.remaining() < len { + return Err(FrameDecodeError::UnexpectedEnd); + } + let mut data = vec![0u8; len]; + buf.copy_to_slice(&mut data); + let reason = String::from_utf8(data).map_err(|_| FrameDecodeError::InvalidUtf8)?; + Ok(Frame::ConnectionClose { code, reason }) + } + _ => Err(FrameDecodeError::UnknownTag(tag)), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum FrameDecodeError { + #[error("unexpected end of buffer")] + UnexpectedEnd, + #[error("unknown frame tag {0}")] + UnknownTag(u64), + #[error("invalid utf-8 in reason")] + InvalidUtf8, +} + +impl From for FrameDecodeError { + fn from(_: VarIntUnexpectedEnd) -> Self { + FrameDecodeError::UnexpectedEnd + } +} diff --git a/websock-mux-proto/src/varint.rs b/websock-mux-proto/src/varint.rs new file mode 100644 index 0000000..cc508b7 --- /dev/null +++ b/websock-mux-proto/src/varint.rs @@ -0,0 +1,241 @@ +//! QUIC variable-length integer encoding and decoding. + +// Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src +// Licensed under Apache-2.0 OR MIT + +use std::{convert::TryInto, fmt}; + +#[cfg(not(target_arch = "wasm32"))] +use std::io::Cursor; + +use bytes::{Buf, BufMut}; +use thiserror::Error; + +#[cfg(not(target_arch = "wasm32"))] +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// An integer less than 2^62. +/// +/// Values of this type are suitable for encoding as QUIC variable-length integer. +// Rust does not currently model that the top two bits are reserved for the length tag. +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value. + pub const MAX: Self = Self((1 << 62) - 1); + /// The largest encoded value length. + pub const MAX_SIZE: usize = 8; + + /// Construct a `VarInt` infallibly. + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Succeeds if `x` < 2^62. + pub fn from_u64(x: u64) -> Result { + if x < 2u64.pow(62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a `VarInt` without checking the bounds. + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the integer value. + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Compute the number of bytes needed to encode this value. + pub fn size(self) -> usize { + let x = self.0; + if x < 2u64.pow(6) { + 1 + } else if x < 2u64.pow(14) { + 2 + } else if x < 2u64.pow(30) { + 4 + } else if x < 2u64.pow(62) { + 8 + } else { + unreachable!("malformed VarInt"); + } + } +} + +impl From for u64 { + fn from(x: VarInt) -> Self { + x.0 + } +} + +impl From for VarInt { + fn from(x: u8) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u16) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u32) -> Self { + Self(x.into()) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: u64) -> Result { + Self::from_u64(x) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: u128) -> Result { + Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds if `x` < 2^62. + fn try_from(x: usize) -> Result { + Self::try_from(x as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl VarInt { + pub fn decode(r: &mut B) -> Result { + if !r.has_remaining() { + return Err(VarIntUnexpectedEnd); + } + let mut buf = [0; 8]; + buf[0] = r.get_u8(); + let tag = buf[0] >> 6; + buf[0] &= 0b0011_1111; + let x = match tag { + 0b00 => u64::from(buf[0]), + 0b01 => { + if r.remaining() < 1 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..2]); + u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) + } + 0b10 => { + if r.remaining() < 3 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..4]); + u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) + } + 0b11 => { + if r.remaining() < 7 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..8]); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + Ok(Self(x)) + } + + // Read a varint from an async stream. + #[cfg(not(target_arch = "wasm32"))] + pub async fn read(stream: &mut S) -> Result { + // Eight bytes is the maximum encoded length. + let mut buf = [0; 8]; + + // Read the first byte because it encodes the length tag. + stream + .read_exact(&mut buf[0..1]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // 0b00 = 1 byte, 0b01 = 2 bytes, 0b10 = 4 bytes, 0b11 = 8 bytes. + let size = 1 << (buf[0] >> 6); + stream + .read_exact(&mut buf[1..size]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // Use a cursor to decode from the stack buffer. + let mut cursor = Cursor::new(&buf[..size]); + let v = VarInt::decode(&mut cursor).unwrap(); + + Ok(v) + } + + pub fn encode(&self, w: &mut B) { + let x = self.0; + if x < 2u64.pow(6) { + w.put_u8(x as u8); + } else if x < 2u64.pow(14) { + w.put_u16((0b01 << 14) | x as u16); + } else if x < 2u64.pow(30) { + w.put_u32((0b10 << 30) | x as u32); + } else if x < 2u64.pow(62) { + w.put_u64((0b11 << 62) | x); + } else { + unreachable!("malformed VarInt") + } + } + + #[cfg(not(target_arch = "wasm32"))] + pub async fn write( + &self, + stream: &mut S, + ) -> Result<(), VarIntUnexpectedEnd> { + // Keep the temporary buffer on the stack to avoid allocation. + let mut buf = [0u8; 8]; + let mut cursor: &mut [u8] = &mut buf; + self.encode(&mut cursor); + let size = 8 - cursor.len(); + + let mut cursor = &buf[..size]; + stream + .write_all_buf(&mut cursor) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + Ok(()) + } +} + +/// Error returned when constructing a `VarInt` from a value >= 2^62. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +#[error("value too large for varint encoding")] +pub struct VarIntBoundsExceeded; + +#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)] +#[error("unexpected end of buffer")] +pub struct VarIntUnexpectedEnd; diff --git a/websock-mux-proto/tests/frame.rs b/websock-mux-proto/tests/frame.rs new file mode 100644 index 0000000..3b34056 --- /dev/null +++ b/websock-mux-proto/tests/frame.rs @@ -0,0 +1,107 @@ +use bytes::{Bytes, BytesMut}; +use std::io::Cursor; + +use websock_mux_proto::{Frame, FrameDecodeError, StreamDir, StreamId, VarInt}; + +fn roundtrip(frame: &Frame) { + let encoded = frame.encode().freeze(); + let mut cur = Cursor::new(encoded); + let decoded = Frame::decode(&mut cur).unwrap(); + assert_eq!(&decoded, frame); +} + +#[test] +fn frame_roundtrip_all_variants() { + let bi_client = StreamId::new(0, false, StreamDir::Bi).unwrap(); + let uni_client = StreamId::new(1, false, StreamDir::Uni).unwrap(); + let bi_server = StreamId::new(2, true, StreamDir::Bi).unwrap(); + + let frames = vec![ + Frame::OpenUni { id: uni_client }, + Frame::OpenBi { id: bi_client }, + Frame::Stream { + id: bi_client, + data: Bytes::from_static(b"hello"), + fin: false, + }, + Frame::Stream { + id: bi_server, + data: Bytes::from(vec![0u8; 1024]), + fin: true, + }, + Frame::ResetStream { + id: bi_client, + code: 0, + }, + Frame::ResetStream { + id: bi_client, + code: 0xdead_beef, + }, + Frame::StopSending { + id: bi_server, + code: 42, + }, + Frame::ConnectionClose { + code: 0, + reason: "bye".to_string(), + }, + Frame::ConnectionClose { + code: 100, + reason: "test".to_string(), + }, + ]; + + for f in frames { + roundtrip(&f); + } +} + +#[test] +fn frame_decode_unknown_tag() { + let mut buf = BytesMut::new(); + VarInt::from_u32(99).encode(&mut buf); // unknown tag + let mut cur = Cursor::new(buf.freeze()); + + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::UnknownTag(99) => {} + other => panic!("unexpected error: {other:?}"), + } +} + +#[test] +fn frame_decode_unexpected_end_stream_payload() { + // Build a Stream frame but truncate the payload. + let id = StreamId::new(0, false, StreamDir::Bi).unwrap(); + + let mut buf = BytesMut::new(); + VarInt::from_u32(2).encode(&mut buf); // Stream + VarInt::from_u64(id.0).unwrap().encode(&mut buf); + VarInt::from_u32(0).encode(&mut buf); // fin = false + VarInt::from_u32(5).encode(&mut buf); // len=5 + buf.extend_from_slice(b"he"); // only 2 bytes, should be 5 + + let mut cur = Cursor::new(buf.freeze()); + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::UnexpectedEnd => {} + other => panic!("unexpected error: {other:?}"), + } +} + +#[test] +fn frame_decode_invalid_utf8_reason() { + // ConnectionClose: tag=5, code=0, len=2, bytes=[0xFF,0xFF] + let mut buf = BytesMut::new(); + VarInt::from_u32(5).encode(&mut buf); + VarInt::from_u32(0).encode(&mut buf); + VarInt::from_u32(2).encode(&mut buf); + buf.extend_from_slice(&[0xFF, 0xFF]); + + let mut cur = Cursor::new(buf.freeze()); + let err = Frame::decode(&mut cur).unwrap_err(); + match err { + FrameDecodeError::InvalidUtf8 => {} + other => panic!("unexpected error: {other:?}"), + } +} diff --git a/websock-mux-proto/tests/varint.rs b/websock-mux-proto/tests/varint.rs new file mode 100644 index 0000000..4b68074 --- /dev/null +++ b/websock-mux-proto/tests/varint.rs @@ -0,0 +1,55 @@ +use bytes::BytesMut; +use std::io::Cursor; + +use websock_mux_proto::{VarInt, VarIntBoundsExceeded}; + +#[test] +fn varint_roundtrip_boundaries() { + // (value, expected_encoded_size) + let cases: &[(u64, usize)] = &[ + (0, 1), + (1, 1), + (63, 1), + (64, 2), + (16383, 2), + (16384, 4), + ((1u64 << 30) - 1, 4), + (1u64 << 30, 8), + ((1u64 << 62) - 1, 8), + ]; + + for &(value, expected_size) in cases { + let v = VarInt::from_u64(value).unwrap(); + assert_eq!(v.size(), expected_size, "size mismatch for {}", value); + + let mut buf = BytesMut::new(); + v.encode(&mut buf); + + // encoded length should match + assert_eq!( + buf.len(), + expected_size, + "encoded length mismatch for {}", + value + ); + + let mut cur = Cursor::new(buf.freeze()); + let decoded = VarInt::decode(&mut cur).unwrap(); + assert_eq!( + decoded.into_inner(), + value, + "roundtrip mismatch for {}", + value + ); + + // should consume all + assert_eq!(cur.position() as usize, expected_size); + } +} + +#[test] +fn varint_rejects_too_large() { + let too_large = 1u64 << 62; + let err = VarInt::from_u64(too_large).unwrap_err(); + assert_eq!(err, VarIntBoundsExceeded); +} diff --git a/websock-mux/Cargo.toml b/websock-mux/Cargo.toml new file mode 100644 index 0000000..898400f --- /dev/null +++ b/websock-mux/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "websock-mux" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "WebSocket multiplexing layer for logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex", "native", "wasm"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +websock-proto = { workspace = true } +websock-mux-proto = { workspace = true } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +websock-tungstenite-mux = { workspace = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +websock-wasm-mux = { workspace = true } + +[dev-dependencies] +tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros"] } +bytes = { workspace = true } +anyhow = "1" +tracing = "0.1" +tracing-subscriber = "0.3" +clap = { version = "4", features = ["derive"] } +url = "2" +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-pemfile = "2.2" + +[[example]] +name = "echo-client-mux" +path = "examples/echo-client-mux.rs" + +[[example]] +name = "echo-server-mux" +path = "examples/echo-server-mux.rs" diff --git a/websock-mux/examples/echo-client-mux.rs b/websock-mux/examples/echo-client-mux.rs new file mode 100644 index 0000000..37d0778 --- /dev/null +++ b/websock-mux/examples/echo-client-mux.rs @@ -0,0 +1,119 @@ +//! Echo client example for the websock mux transport. +//! +//! This demonstrates opening a bidirectional stream and echoing bytes. + +use clap::Parser; +use rustls::{RootCertStore, client::ClientConfig}; +use std::path; +use tracing::Level; +use tracing_subscriber::FmtSubscriber; +use url::Url; +use websock_mux::{ + Client, ClientBuilder, + tls::{self, TlsClientConfigBuilder}, +}; + +const DEFAULT_ECHO_URL: &str = "ws://127.0.0.1:9001"; + +/// Command-line arguments for the echo mux client. +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// WebSocket server URL (ws:// or wss://). + #[arg(short, long)] + url: Option, + + /// Accept the server certificates at this path, encoded as PEM. + #[arg(long)] + tls_cert: Option, + + /// Dangerous: Disable TLS certificate verification. + #[arg(long, default_value = "false")] + tls_disable_verify: bool, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing subscriber for logging in this example. + let subscriber = FmtSubscriber::builder() + .with_max_level(Level::DEBUG) + .finish(); + + tracing::subscriber::set_global_default(subscriber) + .expect("failed to set global tracing subscriber"); + + let args = Args::parse(); + let url = args + .url + .clone() + .unwrap_or_else(|| Url::parse(DEFAULT_ECHO_URL).expect("default url")); + + let client: Client = build_client(&url, &args)?; + tracing::info!("connecting to {}", url); + + let session = client.connect(url.as_str()).await?; + tracing::info!("connected"); + + tracing::info!("opening bidirectional stream"); + let (send, mut recv) = session.open_bi().await?; + tracing::info!("opened bidirectional stream"); + + send.write_all(b"hello mux").await?; + send.finish().await?; + + while let Some(chunk) = recv.read_chunk(1024).await? { + tracing::info!("Received: {}", String::from_utf8_lossy(&chunk)); + } + Ok(()) +} + +/// Build a mux client based on the URL scheme and CLI options. +fn build_client(url: &Url, args: &Args) -> anyhow::Result { + // Plain WS requires no TLS configuration. + if url.scheme() == "ws" { + return Ok(ClientBuilder::new().build()); + } + + // From here: wss:// with TLS enabled. + let is_local = is_localhost_url(url); + + let tls_cfg: ClientConfig = if args.tls_disable_verify { + tracing::warn!("disabling TLS certificate verification"); + TlsClientConfigBuilder::new_insecure()?.build() + } else if let Some(path) = &args.tls_cert { + let certs = tls::cert::load_certs(path)?; + anyhow::ensure!(!certs.is_empty(), "could not find certificate"); + + let mut roots = RootCertStore::empty(); + for c in certs { + let _ = roots.add(c); + } + + ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth() + } else if is_local { + tracing::warn!( + "no --tls-cert provided and target looks local ({}); \ + using insecure mode for quick testing (equivalent to --tls-disable-verify)", + url + ); + TlsClientConfigBuilder::new_insecure()?.build() + } else { + TlsClientConfigBuilder::new_with_native_certs()?.build() + }; + + Ok(ClientBuilder::new() + .with_tls_config(tls_cfg) + .with_default_alpn() + .build()) +} + +/// Determine whether a URL points to a loopback host. +fn is_localhost_url(url: &Url) -> bool { + match url.host_str() { + Some("localhost") => true, + Some(host) => host == "127.0.0.1" || host == "::1", + None => false, + } +} diff --git a/websock-mux/examples/echo-server-mux.rs b/websock-mux/examples/echo-server-mux.rs new file mode 100644 index 0000000..60eebc1 --- /dev/null +++ b/websock-mux/examples/echo-server-mux.rs @@ -0,0 +1,110 @@ +//! Echo server example for the websock mux transport. +//! +//! This server accepts mux sessions and echoes bytes over each bidirectional stream. + +use clap::Parser; +use std::{fs, io, path}; +use tracing::Level; +use tracing_subscriber::FmtSubscriber; +use websock_mux::{Server, ServerBuilder}; + +/// Command-line arguments for the echo mux server. +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Bind address for the server. + #[arg(short, long, default_value = "127.0.0.1:9001")] + addr: std::net::SocketAddr, + + /// Use the certificates at this path, encoded as PEM (enables wss://). + #[arg(long)] + tls_cert: Option, + + /// Use the private key at this path, encoded as PEM (enables wss://). + #[arg(long)] + tls_key: Option, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing subscriber for logging in this example. + let subscriber = FmtSubscriber::builder() + .with_max_level(Level::DEBUG) + .finish(); + tracing::subscriber::set_global_default(subscriber) + .expect("failed to set global tracing subscriber"); + + let args = Args::parse(); + + let mut builder = ServerBuilder::new() + .with_addr(args.addr) + .with_default_alpn(); + + match (&args.tls_cert, &args.tls_key) { + (Some(cert_path), Some(key_path)) => { + let (chain, key) = load_pem_cert_and_key(cert_path, key_path)?; + let cfg = rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(chain, key) + .map_err(|e| anyhow::anyhow!(e.to_string()))?; + builder = builder.with_tls_config(cfg); + tracing::info!("TLS enabled (wss://)"); + } + (None, None) => { + tracing::warn!("TLS disabled (ws://). Provide --tls-cert/--tls-key to enable wss://"); + } + _ => anyhow::bail!("both --tls-cert and --tls-key must be provided together, or neither"), + } + + let server: Server = builder.build().await?; + let scheme = if args.tls_cert.is_some() { "wss" } else { "ws" }; + tracing::info!("listening on {} ({}://{})", args.addr, scheme, args.addr); + + loop { + let session = server.accept().await?; + tracing::info!("accepted connection"); + tokio::spawn(async move { + loop { + let (send, mut recv) = match session.accept_bi().await { + Ok(streams) => streams, + Err(_) => break, + }; + tracing::info!("accepted bidirectional stream"); + tokio::spawn(async move { + let send = send; + while let Ok(Some(chunk)) = recv.read_chunk(1024).await { + let chunk_len = chunk.len(); + tracing::info!("Received chunk of size: {}", chunk_len); + if send.write_buf(chunk).await.is_err() { + break; + } + tracing::info!("Sent chunk of size: {}", chunk_len); + } + let _ = send.finish().await; + }); + } + }); + } +} + +/// Load a PEM-encoded certificate chain and private key from disk. +fn load_pem_cert_and_key( + cert_path: &path::Path, + key_path: &path::Path, +) -> anyhow::Result<( + Vec>, + rustls::pki_types::PrivateKeyDer<'static>, +)> { + let chain_file = fs::File::open(cert_path)?; + let mut chain_reader = io::BufReader::new(chain_file); + let chain: Vec> = + rustls_pemfile::certs(&mut chain_reader).collect::>()?; + anyhow::ensure!(!chain.is_empty(), "could not find certificate"); + + let key_file = fs::File::open(key_path)?; + let mut key_reader = io::BufReader::new(key_file); + let key = rustls_pemfile::private_key(&mut key_reader)? + .ok_or_else(|| anyhow::anyhow!("missing private key"))?; + + Ok((chain, key)) +} diff --git a/websock-mux/src/lib.rs b/websock-mux/src/lib.rs new file mode 100644 index 0000000..8de2e42 --- /dev/null +++ b/websock-mux/src/lib.rs @@ -0,0 +1,11 @@ +pub use websock_proto::*; + +#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] +#[path = "tungstenite.rs"] +mod websocket; + +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +#[path = "wasm.rs"] +mod websocket; + +pub use websocket::*; diff --git a/websock-mux/src/tungstenite.rs b/websock-mux/src/tungstenite.rs new file mode 100644 index 0000000..c47354d --- /dev/null +++ b/websock-mux/src/tungstenite.rs @@ -0,0 +1 @@ +pub use websock_tungstenite_mux::*; diff --git a/websock-mux/src/wasm.rs b/websock-mux/src/wasm.rs new file mode 100644 index 0000000..4613d1e --- /dev/null +++ b/websock-mux/src/wasm.rs @@ -0,0 +1 @@ +pub use websock_wasm_mux::*; diff --git a/websock-proto/src/error.rs b/websock-proto/src/error.rs index 78c4a7f..d54b9e6 100644 --- a/websock-proto/src/error.rs +++ b/websock-proto/src/error.rs @@ -30,6 +30,12 @@ pub enum Error { #[error("unsupported: {0}")] Unsupported(String), + #[error("frame decode error: {0}")] + FrameDecode(String), + + #[error("stream id error: {0}")] + StreamId(String), + /// A catch-all error for unexpected failures. #[error("other error: {0}")] Other(String), diff --git a/websock-proto/src/lib.rs b/websock-proto/src/lib.rs index 736fc21..f44f5da 100644 --- a/websock-proto/src/lib.rs +++ b/websock-proto/src/lib.rs @@ -10,4 +10,4 @@ mod options; pub use bytes::Bytes; pub use error::{Error, Result}; pub use message::{CloseFrame, Message}; -pub use options::{ConnectOptions, ServerOptions}; +pub use options::{ConnectOptions, ServerOptions, default_ws_alpn}; diff --git a/websock-proto/src/options.rs b/websock-proto/src/options.rs index 31459b0..187e372 100644 --- a/websock-proto/src/options.rs +++ b/websock-proto/src/options.rs @@ -1,3 +1,6 @@ +pub const ALPN_HTTP_1_1: &[u8] = b"http/1.1"; +pub const ALPN_H2: &[u8] = b"h2"; + /// Connection configuration shared by native and WebAssembly transports. #[derive(Debug, Clone, Default)] pub struct ConnectOptions { @@ -17,3 +20,7 @@ pub struct ServerOptions { /// Additional response headers (native only). pub headers: Vec<(String, String)>, } + +pub fn default_ws_alpn() -> Vec> { + vec![ALPN_HTTP_1_1.to_vec(), ALPN_H2.to_vec()] +} diff --git a/websock-tungstenite-mux/Cargo.toml b/websock-tungstenite-mux/Cargo.toml new file mode 100644 index 0000000..6be528c --- /dev/null +++ b/websock-tungstenite-mux/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "websock-tungstenite-mux" +version.workspace = true +edition.workspace = true +authors.workspace = true +description = "Native WebSocket multiplexing layer for logical streams" +repository = "https://github.com/foctal/websock" +readme = "../README.md" +keywords = ["network", "websocket", "multiplex", "native", "wasm"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +bytes = { workspace = true } +thiserror = "2" +websock-proto = { workspace = true } +websock-mux-proto = { workspace = true } +tokio = { version = "1", features = ["rt", "sync", "macros"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"]} +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-native-certs = "0.8" +rcgen = "0.14" +rustls-pemfile = "2.2" +tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] } +futures-util = { version = "0.3" } +futures-core = { version = "0.3" } +futures-sink = { version = "0.3" } +tokio-util = { version = "0.7" } +#http = "1" +tracing = "0.1" diff --git a/websock-tungstenite-mux/src/builder.rs b/websock-tungstenite-mux/src/builder.rs new file mode 100644 index 0000000..c451308 --- /dev/null +++ b/websock-tungstenite-mux/src/builder.rs @@ -0,0 +1,184 @@ +use rustls::{ClientConfig, ServerConfig}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::sync::Arc; +use websock_proto::{Result, default_ws_alpn}; + +use crate::session::Limits; +use crate::{Client, Server, bind}; +use websock_mux_proto::SUBPROTOCOL; + +/// Builder for a mux WebSocket client. +#[derive(Debug, Clone)] +pub struct ClientBuilder { + pub(crate) opts: websock_proto::ConnectOptions, + pub(crate) tls: Option, + pub(crate) alpn: Option>>, + pub(crate) limits: Limits, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ClientBuilder { + pub fn new() -> Self { + let mut opts = websock_proto::ConnectOptions::default(); + opts.protocols.insert(0, SUBPROTOCOL.to_string()); + Self { + opts, + tls: None, + alpn: None, + limits: Limits::default(), + } + } + + pub fn with_options(mut self, opts: websock_proto::ConnectOptions) -> Self { + self.opts = opts; + if !self.opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + self.opts.protocols.insert(0, SUBPROTOCOL.to_string()); + } + self + } + + /// Configure a custom rustls client config (for `wss://`). + pub fn with_tls_config(mut self, tls: ClientConfig) -> Self { + self.tls = Some(tls); + self + } + + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + /// Configure session limits + pub fn with_limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.opts.headers.push((name.into(), value.into())); + self + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.opts.protocols.push(protocol.into()); + self + } + + fn build_tls_config(&self) -> Option> { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(Arc::new(cfg)) + } + + pub fn build(&self) -> Client { + Client { + opts: self.opts.clone(), + tls: self.build_tls_config(), + limits: self.limits.clone(), + } + } +} + +/// Builder for a mux WebSocket server. +#[derive(Debug, Clone)] +pub struct ServerBuilder { + pub(crate) addr: SocketAddr, + pub(crate) opts: websock_proto::ServerOptions, + pub(crate) tls: Option, + pub(crate) alpn: Option>>, + pub(crate) limits: Limits, +} + +impl ServerBuilder { + pub fn new() -> Self { + let mut opts = websock_proto::ServerOptions::default(); + opts.protocols.push(SUBPROTOCOL.to_string()); + Self { + addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)), + opts, + tls: None, + alpn: None, + limits: Limits::default(), + } + } + + /// Set the bind address. + pub fn with_addr(mut self, addr: impl Into) -> Self { + self.addr = addr.into(); + self + } + + pub fn with_options(mut self, opts: websock_proto::ServerOptions) -> Self { + self.opts = opts; + if !self.opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + self.opts.protocols.insert(0, SUBPROTOCOL.to_string()); + } + self + } + + /// Configure TLS for incoming connections (accept `wss://`). + pub fn with_tls_config(mut self, tls: ServerConfig) -> Self { + self.tls = Some(tls); + self + } + + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + /// Configure session limits + pub fn with_limits(mut self, limits: Limits) -> Self { + self.limits = limits; + self + } + + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.opts.headers.push((name.into(), value.into())); + self + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.opts.protocols.push(protocol.into()); + self + } + + fn build_tls_config(&self) -> Option { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(cfg) + } + + pub async fn build(&self) -> Result { + bind( + self.addr, + self.opts.clone(), + self.build_tls_config(), + self.limits.clone(), + ) + .await + } +} diff --git a/websock-tungstenite-mux/src/client.rs b/websock-tungstenite-mux/src/client.rs new file mode 100644 index 0000000..cb46ec0 --- /dev/null +++ b/websock-tungstenite-mux/src/client.rs @@ -0,0 +1,104 @@ +use tokio_tungstenite::tungstenite; +use tungstenite::client::IntoClientRequest; +use tungstenite::http; +use tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL; + +use std::sync::Arc; + +use rustls::ClientConfig; +use tokio_tungstenite::Connector; +use websock_proto::{Error, Result}; + +use crate::Session; +use crate::session::Limits; +use crate::session::map_tungstenite_err; +use websock_mux_proto::SUBPROTOCOL; + +fn negotiated_protocol(resp: &http::Response>>) -> Option<&str> { + resp.headers() + .get(SEC_WEBSOCKET_PROTOCOL)? + .to_str() + .ok() + .map(|s| s.trim()) +} + +fn is_protocol_ok(p: &str) -> bool { + p.eq_ignore_ascii_case(SUBPROTOCOL) +} + +fn validate_client_protocols(opts: &websock_proto::ConnectOptions) -> Result<()> { + for p in &opts.protocols { + tungstenite::http::HeaderValue::from_str(p) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + } + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) opts: websock_proto::ConnectOptions, + pub(crate) tls: Option>, + pub(crate) limits: Limits, +} + +impl Client { + pub fn options(&self) -> &websock_proto::ConnectOptions { + &self.opts + } + + /// Return the configured TLS client config (if any). + pub fn tls_config(&self) -> Option<&Arc> { + self.tls.as_ref() + } + + pub async fn connect(&self, url: &str) -> Result { + self.connect_with_tls(url, self.tls.clone()).await + } + + /// Connect with an explicit TLS configuration. + /// + /// When `tls` is `None`, the default Tokio Tungstenite TLS settings are used + /// (and plain `ws://` works as-is). + pub async fn connect_with_tls( + &self, + url: &str, + tls: Option>, + ) -> Result { + validate_client_protocols(&self.opts)?; + + let mut request = url + .into_client_request() + .map_err(|e| websock_proto::Error::InvalidUrl(e.to_string()))?; + + let headers = request.headers_mut(); + for (k, v) in self.opts.headers.iter() { + let name = tungstenite::http::header::HeaderName::from_bytes(k.as_bytes()) + .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?; + let value = tungstenite::http::header::HeaderValue::from_str(&v) + .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?; + headers.append(name, value); + } + + // Apply subprotocols. + if !self.opts.protocols.is_empty() { + let joined = self.opts.protocols.join(","); + let value = tungstenite::http::header::HeaderValue::from_str(&joined) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + headers.insert(SEC_WEBSOCKET_PROTOCOL, value); + } + + let connector = tls.map(Connector::Rustls); + let (stream, response) = + tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector) + .await + .map_err(map_tungstenite_err)?; + + let proto = negotiated_protocol(&response) + .ok_or_else(|| Error::Protocol("missing SEC_WEBSOCKET_PROTOCOL in response".into()))?; + if !is_protocol_ok(proto) { + return Err(Error::Protocol(format!("subprotocol mismatch: {proto}"))); + } + + Session::new(stream, false, self.limits.clone()) + } +} diff --git a/websock-tungstenite-mux/src/lib.rs b/websock-tungstenite-mux/src/lib.rs new file mode 100644 index 0000000..6cc0aca --- /dev/null +++ b/websock-tungstenite-mux/src/lib.rs @@ -0,0 +1,79 @@ +//! Tokio + tokio-tungstenite based WebSocket multiplexing transport. +//! +//! This crate provides a QUIC/WebTransport-like logical stream interface over a single WebSocket. + +mod builder; +mod client; +mod server; +mod session; +pub mod tls; + +pub use builder::{ClientBuilder, ServerBuilder}; +pub use client::Client; +pub use server::{Server, bind}; +pub use session::Limits; +pub use session::{RecvStream, SendStream, Session}; +pub use tls::{ + TlsClientConfig, TlsClientConfigBuilder, TlsConfig, TlsServerConfig, TlsServerConfigBuilder, +}; + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use tokio::sync::mpsc; + use websock_mux_proto::{Frame, StreamDir, StreamId}; + + use crate::session::Limits; + use crate::session::SessionInner; + + #[test] + fn frame_roundtrip() { + let id = StreamId(4); + let frame = Frame::Stream { + id, + data: Bytes::from_static(b"hello"), + fin: true, + }; + let mut buf = frame.encode(); + let decoded = Frame::decode(&mut buf).expect("decode"); + assert_eq!(frame, decoded); + } + + #[tokio::test] + async fn stream_open_data_fin() { + let (outbound_tx, _outbound_rx) = mpsc::channel(4); + let (accept_uni_tx, mut accept_uni_rx) = mpsc::channel(4); + let (accept_bi_tx, _accept_bi_rx) = mpsc::channel(4); + let inner = std::sync::Arc::new(SessionInner::new( + false, + Limits::default(), + outbound_tx, + accept_uni_tx, + accept_bi_tx, + )); + let id = StreamId::new(0, true, StreamDir::Uni).expect("stream id"); + inner + .clone() + .handle_frame(Frame::OpenUni { id }) + .await + .expect("open"); + let mut recv = accept_uni_rx.recv().await.expect("recv stream"); + let data = Bytes::from_static(b"ping"); + inner + .clone() + .handle_frame(Frame::Stream { + id, + data: data.clone(), + fin: true, + }) + .await + .expect("stream"); + let mut buf = BytesMut::new(); + let n = recv.read_buf::(&mut buf).await.expect("read"); + assert_eq!(n, Some(4)); + assert_eq!(buf.as_ref(), data.as_ref()); + let max_size: usize = 1024; + let end = recv.read_chunk(max_size).await.expect("fin"); + assert!(end.is_none()); + } +} diff --git a/websock-tungstenite-mux/src/server.rs b/websock-tungstenite-mux/src/server.rs new file mode 100644 index 0000000..744fc38 --- /dev/null +++ b/websock-tungstenite-mux/src/server.rs @@ -0,0 +1,168 @@ +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio_rustls::TlsAcceptor; +use tokio_tungstenite::accept_hdr_async; +use tokio_tungstenite::tungstenite; +use tokio_tungstenite::tungstenite::handshake::server; +use tungstenite::http; +use tungstenite::http::header::{HeaderName, HeaderValue, SEC_WEBSOCKET_PROTOCOL}; + +use websock_proto::{Error, Result}; + +use crate::Session; +use crate::session::Limits; +use crate::session::map_tungstenite_err; +use websock_mux_proto::SUBPROTOCOL; + +/// Marker trait for IO types usable by the server. +pub trait ServerIo: AsyncRead + AsyncWrite + Unpin + Send {} + +impl ServerIo for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +/// Boxed stream type used for server connections. +pub type ServerStream = Box; + +/// Convert configured headers into tungstenite types. +fn prepare_headers(opts: &websock_proto::ServerOptions) -> Result> { + let mut out = Vec::new(); + for (k, v) in &opts.headers { + let name = HeaderName::from_bytes(k.as_bytes()) + .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?; + let value = HeaderValue::from_str(v) + .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?; + out.push((name, value)); + } + Ok(out) +} + +/// Validate configured subprotocols before binding. +fn validate_protocols(opts: &websock_proto::ServerOptions) -> Result<()> { + for protocol in &opts.protocols { + HeaderValue::from_str(protocol) + .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?; + } + Ok(()) +} + +/// Select the first requested subprotocol that appears in the allowed list. +fn select_protocol(req: &server::Request, allowed: &[String]) -> Option { + if allowed.is_empty() { + return None; + } + let header = req.headers().get(SEC_WEBSOCKET_PROTOCOL)?; + let header = header.to_str().ok()?; + for candidate in header.split(',').map(|s| s.trim()) { + if allowed.iter().any(|p| p == candidate) { + return Some(candidate.to_string()); + } + } + None +} + +/// Bind a WebSocket server listener. +pub async fn bind( + addr: A, + opts: websock_proto::ServerOptions, + tls: Option, + limits: Limits, +) -> Result +where + A: ToSocketAddrs, +{ + let listener = TcpListener::bind(addr) + .await + .map_err(|e| Error::Io(e.to_string()))?; + let headers = prepare_headers(&opts)?; + validate_protocols(&opts)?; + + if !opts.protocols.iter().any(|p| p == SUBPROTOCOL) { + return Err(Error::Protocol( + "SUBPROTOCOL missing in ServerOptions".into(), + )); + } + + let acceptor = tls.map(|cfg| TlsAcceptor::from(Arc::new(cfg))); + + Ok(Server { + listener, + opts, + headers, + acceptor, + limits, + }) +} + +pub struct Server { + listener: TcpListener, + opts: websock_proto::ServerOptions, + headers: Vec<(HeaderName, HeaderValue)>, + acceptor: Option, + limits: Limits, +} + +impl Server { + pub async fn accept(&self) -> Result { + let (stream, _) = self + .listener + .accept() + .await + .map_err(|e| Error::Io(e.to_string()))?; + + let (stream, _is_tls): (ServerStream, bool) = if let Some(acceptor) = &self.acceptor { + let tls_stream = acceptor + .accept(stream) + .await + .map_err(|e| Error::Tls(e.to_string()))?; + (Box::new(tls_stream), true) + } else { + (Box::new(stream), false) + }; + + let headers = self.headers.clone(); + let allowed = self.opts.protocols.clone(); + + let ws = accept_hdr_async( + stream, + move |req: &server::Request, mut resp: server::Response| { + // Additional headers from configuration + for (name, value) in &headers { + resp.headers_mut().append(name, value.clone()); + } + + // websock-mux is required protocol + let Some(protocol) = select_protocol(req, &allowed) else { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + }; + + // Confirm that the required protocol is present + if !protocol.eq_ignore_ascii_case(SUBPROTOCOL) { + return Err(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(Some(format!("'{SUBPROTOCOL}' protocol required"))) + .unwrap()); + } + + resp.headers_mut().insert( + http::header::SEC_WEBSOCKET_PROTOCOL, + http::HeaderValue::from_str(&protocol).expect("validated"), + ); + + Ok(resp) + }, + ) + .await + .map_err(map_tungstenite_err)?; + + Session::new(ws, true, self.limits.clone()) + } + + pub fn local_addr(&self) -> Result { + self.listener + .local_addr() + .map_err(|e| Error::Io(e.to_string())) + } +} diff --git a/websock-tungstenite-mux/src/session.rs b/websock-tungstenite-mux/src/session.rs new file mode 100644 index 0000000..f4b5bdf --- /dev/null +++ b/websock-tungstenite-mux/src/session.rs @@ -0,0 +1,620 @@ +use std::collections::HashMap; +use std::sync::{ + Arc, + atomic::{AtomicBool, AtomicU64, Ordering}, +}; + +use bytes::{BufMut, Bytes}; +use futures_util::{SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::{Mutex, mpsc}; +use tokio_tungstenite::tungstenite; +use websock_proto::{Error, Result}; + +use websock_mux_proto::stream::{Frame, StreamDir, StreamId}; + +/// Session limits to prevent unbounded buffering / DoS. +#[derive(Debug, Clone)] +pub struct Limits { + /// Maximum size of a single WebSocket binary message accepted by the inbound task. + pub max_ws_message_size: usize, + /// Maximum `Stream` frame payload size. + pub max_stream_data_per_frame: usize, + /// Maximum number of concurrently open receive streams. + pub max_open_streams: usize, + /// Per-stream receive event queue length. + pub recv_event_queue_len: usize, + /// Session outbound queue length. + pub outbound_queue_len: usize, + /// Queue length for accepting inbound uni streams. + pub accept_uni_queue_len: usize, + /// Queue length for accepting inbound bi streams. + pub accept_bi_queue_len: usize, +} + +impl Default for Limits { + fn default() -> Self { + Self { + // Safe defaults for a WebSocket fallback transport. + max_ws_message_size: 1 * 1024 * 1024, // 1 MiB + max_stream_data_per_frame: 256 * 1024, // 256 KiB + max_open_streams: 1024, + recv_event_queue_len: 128, + outbound_queue_len: 256, + accept_uni_queue_len: 128, + accept_bi_queue_len: 128, + } + } +} + +#[derive(Clone)] +pub struct Session { + inner: Arc, + accept_uni: Arc>>, + accept_bi: Arc>>, +} + +impl Session { + pub(crate) fn new( + stream: tokio_tungstenite::WebSocketStream, + is_server: bool, + limits: Limits, + ) -> Result + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (outbound_tx, outbound_rx) = mpsc::channel(limits.outbound_queue_len); + let (accept_uni_tx, accept_uni_rx) = mpsc::channel(limits.accept_uni_queue_len); + let (accept_bi_tx, accept_bi_rx) = mpsc::channel(limits.accept_bi_queue_len); + + let inner = Arc::new(SessionInner::new( + is_server, + limits, + outbound_tx, + accept_uni_tx, + accept_bi_tx, + )); + + let session = Self { + inner: inner.clone(), + accept_uni: Arc::new(Mutex::new(accept_uni_rx)), + accept_bi: Arc::new(Mutex::new(accept_bi_rx)), + }; + + inner.spawn_tasks(stream, outbound_rx); + Ok(session) + } + + pub async fn open_uni(&self) -> Result { + let id = self.inner.next_stream_id(StreamDir::Uni)?; + self.inner.send_frame(Frame::OpenUni { id }).await?; + Ok(SendStream::new(id, self.inner.clone())) + } + + pub async fn open_bi(&self) -> Result<(SendStream, RecvStream)> { + let id = self.inner.next_stream_id(StreamDir::Bi)?; + let recv = self.inner.register_recv_stream(id).await; + self.inner.send_frame(Frame::OpenBi { id }).await?; + Ok((SendStream::new(id, self.inner.clone()), recv)) + } + + pub async fn accept_uni(&self) -> Result { + let mut rx = self.accept_uni.lock().await; + rx.recv().await.ok_or(Error::Closed) + } + + pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream)> { + let mut rx = self.accept_bi.lock().await; + rx.recv().await.ok_or(Error::Closed) + } +} + +#[derive(Clone)] +pub struct SendStream { + id: StreamId, + session: Arc, + finished: Arc, +} + +impl SendStream { + fn new(id: StreamId, session: Arc) -> Self { + Self { + id, + session, + finished: Arc::new(AtomicBool::new(false)), + } + } + + pub async fn write(&self, data: &[u8]) -> Result<()> { + self.write_buf(Bytes::copy_from_slice(data)).await + } + + pub async fn write_buf(&self, data: Bytes) -> Result<()> { + self.session + .send_frame(Frame::Stream { + id: self.id, + data, + fin: false, + }) + .await + } + + pub async fn write_all(&self, data: &[u8]) -> Result<()> { + self.write(data).await + } + + pub async fn finish(&self) -> Result<()> { + if self + .finished + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + self.session + .send_frame(Frame::Stream { + id: self.id, + data: Bytes::new(), + fin: true, + }) + .await?; + } + Ok(()) + } + + pub async fn reset(&self, code: u64) -> Result<()> { + self.finished.store(true, Ordering::SeqCst); + self.session + .send_frame(Frame::ResetStream { id: self.id, code }) + .await + } + + pub fn closed(&self) -> bool { + self.finished.load(Ordering::SeqCst) + } +} + +impl Drop for SendStream { + fn drop(&mut self) { + if !self.finished.load(Ordering::SeqCst) { + self.session.try_send_frame(Frame::ResetStream { + id: self.id, + code: 0, + }); + } + } +} + +#[derive(Debug)] +struct RecvEvent { + data: Bytes, + fin: bool, +} + +pub struct RecvStream { + id: StreamId, + session: Arc, + receiver: mpsc::Receiver, + finished: bool, + pending: Bytes, +} + +impl RecvStream { + fn new(id: StreamId, session: Arc, receiver: mpsc::Receiver) -> Self { + Self { + id, + session, + receiver, + finished: false, + pending: Bytes::new(), + } + } + + pub async fn read(&mut self, buf: &mut [u8]) -> Result> { + if self.finished { + return Ok(None); + } + if self.pending.is_empty() { + if let Some(chunk) = self.read_chunk_internal().await? { + self.pending = chunk; + } else { + return Ok(None); + } + } + let amt = buf.len().min(self.pending.len()); + buf[..amt].copy_from_slice(&self.pending[..amt]); + self.pending = self.pending.slice(amt..); + Ok(Some(amt)) + } + + pub async fn read_buf(&mut self, buf: &mut B) -> Result> { + let mut temp = vec![0u8; 4096]; + let size_opt = self.read(&mut temp).await?; + if let Some(size) = size_opt { + buf.put_slice(&temp[..size]); + Ok(Some(size)) + } else { + Ok(None) + } + } + + pub async fn read_chunk_internal(&mut self) -> Result> { + match self.receiver.recv().await { + Some(event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub async fn read_chunk(&mut self, max: usize) -> Result> { + match self.receiver.recv().await { + Some(mut event) => { + if event.fin { + self.finished = true; + } + if event.data.is_empty() && event.fin { + Ok(None) + } else if event.data.len() > max { + let chunk = event.data.split_to(max); + self.pending = event.data; + Ok(Some(chunk)) + } else { + Ok(Some(event.data)) + } + } + None => { + self.finished = true; + Ok(None) + } + } + } + + pub async fn stop(&self, code: u64) -> Result<()> { + self.session + .send_frame(Frame::StopSending { id: self.id, code }) + .await + } + + pub fn closed(&self) -> bool { + self.finished + } +} + +impl Drop for RecvStream { + fn drop(&mut self) { + if !self.finished { + self.session.try_send_frame(Frame::StopSending { + id: self.id, + code: 0, + }); + } + } +} + +pub(crate) struct SessionInner { + is_server: bool, + limits: Limits, + outbound_tx: mpsc::Sender, + accept_uni_tx: Mutex>>, + accept_bi_tx: Mutex>>, + streams: Mutex>>, + next_uni: AtomicU64, + next_bi: AtomicU64, + closed: AtomicBool, +} + +impl SessionInner { + pub(crate) fn new( + is_server: bool, + limits: Limits, + outbound_tx: mpsc::Sender, + accept_uni_tx: mpsc::Sender, + accept_bi_tx: mpsc::Sender<(SendStream, RecvStream)>, + ) -> Self { + Self { + is_server, + limits, + outbound_tx, + accept_uni_tx: Mutex::new(Some(accept_uni_tx)), + accept_bi_tx: Mutex::new(Some(accept_bi_tx)), + streams: Mutex::new(HashMap::new()), + next_uni: AtomicU64::new(0), + next_bi: AtomicU64::new(0), + closed: AtomicBool::new(false), + } + } + + fn spawn_tasks( + self: Arc, + stream: tokio_tungstenite::WebSocketStream, + mut outbound_rx: mpsc::Receiver, + ) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let (mut ws_sink, mut ws_stream) = stream.split(); + + let inbound = self.clone(); + tokio::spawn(async move { + while let Some(msg) = ws_stream.next().await { + let msg = match msg { + Ok(m) => m, + Err(_) => break, + }; + + match msg { + tungstenite::Message::Binary(data) => { + if data.len() > inbound.limits.max_ws_message_size { + let _ = inbound.protocol_error(2, "ws message too large").await; + break; + } + let mut cursor = &data[..]; + let frame = match Frame::decode(&mut cursor) { + Ok(f) => f, + Err(_) => break, + }; + if inbound.clone().handle_frame(frame).await.is_err() { + break; + } + } + tungstenite::Message::Ping(p) => { + let _ = inbound.outbound_tx.try_send(tungstenite::Message::Pong(p)); + } + tungstenite::Message::Close(_) => break, + _ => {} + } + } + + inbound.close_all().await; + }); + + let outbound = self.clone(); + tokio::spawn(async move { + while let Some(msg) = outbound_rx.recv().await { + if ws_sink.send(msg).await.is_err() { + break; + } + } + outbound.close_all().await; + }); + } + + pub(crate) async fn handle_frame(self: Arc, frame: Frame) -> Result<()> { + match frame { + Frame::OpenUni { id } => { + if id.dir() != StreamDir::Uni { + return self + .protocol_error(1, "OpenUni with non-uni StreamId") + .await; + } + if id.initiator_is_server() != self.peer_is_server() { + return self.protocol_error(1, "OpenUni with wrong initiator").await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + { + let streams = self.streams.lock().await; + if streams.len() >= self.limits.max_open_streams { + return self.protocol_error(3, "too many open streams").await; + } + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + let recv = self.register_recv_stream(id).await; + + let tx = { self.accept_uni_tx.lock().await.clone() }; + if let Some(tx) = tx { + match tx.try_send(recv) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + // Application is not accepting inbound streams fast enough. + // Reset this stream and keep the connection alive. + let mut streams = self.streams.lock().await; + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => return Err(Error::Closed), + } + } else { + return Err(Error::Closed); + } + } + Frame::OpenBi { id } => { + if id.dir() != StreamDir::Bi { + return self.protocol_error(1, "OpenBi with non-bi StreamId").await; + } + if id.initiator_is_server() != self.peer_is_server() { + return self.protocol_error(1, "OpenBi with wrong initiator").await; + } + + { + let streams = self.streams.lock().await; + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + { + let streams = self.streams.lock().await; + if streams.len() >= self.limits.max_open_streams { + return self.protocol_error(3, "too many open streams").await; + } + if streams.contains_key(&id) { + return self.protocol_error(1, "duplicate stream open").await; + } + } + + let recv = self.register_recv_stream(id).await; + let send = SendStream::new(id, self.clone()); + + let tx = { self.accept_bi_tx.lock().await.clone() }; + if let Some(tx) = tx { + match tx.try_send((send, recv)) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + let mut streams = self.streams.lock().await; + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => return Err(Error::Closed), + } + } else { + return Err(Error::Closed); + } + } + Frame::Stream { id, data, fin } => { + if data.len() > self.limits.max_stream_data_per_frame { + return self.protocol_error(2, "stream data too large").await; + } + let mut streams = self.streams.lock().await; + + let Some(tx) = streams.get(&id) else { + return self + .protocol_error(1, "Stream data on unknown stream") + .await; + }; + + match tx.try_send(RecvEvent { data, fin }) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + // Receiver is too slow. Reset just this stream to prevent + // unbounded buffering or blocking the inbound loop. + streams.remove(&id); + self.try_send_frame(Frame::ResetStream { id, code: 3 }); + return Ok(()); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + streams.remove(&id); + return Ok(()); + } + } + + if fin { + streams.remove(&id); + } + } + Frame::ResetStream { id, .. } | Frame::StopSending { id, .. } => { + let mut streams = self.streams.lock().await; + if streams.remove(&id).is_none() { + return self.protocol_error(1, "reset/stop on unknown stream").await; + } + } + Frame::ConnectionClose { .. } => { + let mut streams = self.streams.lock().await; + streams.clear(); + } + } + Ok(()) + } + + pub(crate) async fn register_recv_stream(self: &Arc, id: StreamId) -> RecvStream { + let (tx, rx) = mpsc::channel(self.limits.recv_event_queue_len); + let mut streams = self.streams.lock().await; + streams.insert(id, tx); + RecvStream::new(id, self.clone(), rx) + } + + pub(crate) fn next_stream_id(&self, dir: StreamDir) -> Result { + let counter = match dir { + StreamDir::Uni => self.next_uni.fetch_add(1, Ordering::SeqCst), + StreamDir::Bi => self.next_bi.fetch_add(1, Ordering::SeqCst), + }; + StreamId::new(counter, self.is_server, dir).map_err(|e| Error::StreamId(e.to_string())) + } + + pub(crate) async fn send_frame(&self, frame: Frame) -> Result<()> { + if self.is_closed() { + return Err(Error::Closed); + } + let data = frame.encode().freeze(); + self.outbound_tx + .send(tungstenite::Message::Binary(data)) + .await + .map_err(|_| Error::Closed) + } + + pub(crate) fn try_send_frame(&self, frame: Frame) { + if self.is_closed() { + return; + } + let data = frame.encode().freeze(); + let _ = self + .outbound_tx + .try_send(tungstenite::Message::Binary(data)); + } + + pub(crate) async fn close_all(&self) { + if self.closed.swap(true, Ordering::SeqCst) { + return; + } + // Close accept channels + { + let mut tx = self.accept_uni_tx.lock().await; + tx.take(); + } + { + let mut tx = self.accept_bi_tx.lock().await; + tx.take(); + } + // Close existing streams + { + let mut streams = self.streams.lock().await; + streams.clear(); + } + } + + pub(crate) fn is_closed(&self) -> bool { + self.closed.load(Ordering::SeqCst) + } + + async fn protocol_error(self: &Arc, code: u64, reason: impl Into) -> Result<()> { + let reason = reason.into(); + + // Notify the peer (it's okay if sending fails...) + let _ = self.try_send_frame(Frame::ConnectionClose { + code, + reason: reason.clone(), + }); + + self.close_all().await; + Err(Error::Protocol(reason)) + } + + fn peer_is_server(&self) -> bool { + !self.is_server + } +} + +/// Map tungstenite errors into the shared error type. +pub(crate) fn map_tungstenite_err(e: tungstenite::Error) -> Error { + use tungstenite::Error as E; + match e { + E::ConnectionClosed | E::AlreadyClosed => Error::Closed, + E::Io(io) => Error::Io(io.to_string()), + E::Tls(tls) => Error::Tls(tls.to_string()), + E::Url(url) => Error::InvalidUrl(url.to_string()), + E::Protocol(err) => Error::Protocol(err.to_string()), + E::Utf8(err) => Error::Protocol(err), + E::Capacity(err) => Error::Protocol(err.to_string()), + E::HttpFormat(err) => Error::Protocol(err.to_string()), + other => Error::Other(other.to_string()), + } +} diff --git a/websock-tungstenite-mux/src/tls/cert.rs b/websock-tungstenite-mux/src/tls/cert.rs new file mode 100644 index 0000000..a1cbf0f --- /dev/null +++ b/websock-tungstenite-mux/src/tls/cert.rs @@ -0,0 +1,96 @@ +//! Certificate handling utilities. + +use rustls::client::danger::ServerCertVerifier; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; +use std::{fs, path::Path, sync::Arc}; +use websock_proto::{Error, Result}; + +/// Load native certificates from the host system. +pub fn get_native_certs() -> Result { + let mut root_store = rustls::RootCertStore::empty(); + + let cert_result = rustls_native_certs::load_native_certs(); + + for cert in cert_result.certs { + let _ = root_store.add(cert); + } + + Ok(root_store) +} + +/// Load a certificate chain from a file (.der or .pem). +pub fn load_certs(cert_path: &Path) -> Result>> { + let cert_bytes = fs::read(cert_path).map_err(|e| Error::Io(e.to_string()))?; + + if cert_path.extension().map_or(false, |x| x == "der") { + return Ok(vec![CertificateDer::from(cert_bytes)]); + } + + rustls_pemfile::certs(&mut &*cert_bytes) + .collect::, std::io::Error>>() + .map_err(|e| Error::Io(e.to_string())) +} + +/// Certificate verifier that unconditionally accepts certificates. +/// +/// # Warning +/// This is vulnerable to MITM attacks and must only be used for testing. +#[derive(Debug)] +pub struct SkipServerVerification(Arc); + +impl SkipServerVerification { + /// Create a verifier using the default ring provider. + pub fn new() -> Arc { + Self::with_provider(Arc::new(rustls::crypto::ring::default_provider())) + } + + /// Create a verifier with the provided crypto provider. + pub fn with_provider(provider: Arc) -> Arc { + Arc::new(Self(provider)) + } +} + +impl ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> std::result::Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} diff --git a/websock-tungstenite-mux/src/tls/key.rs b/websock-tungstenite-mux/src/tls/key.rs new file mode 100644 index 0000000..10c8987 --- /dev/null +++ b/websock-tungstenite-mux/src/tls/key.rs @@ -0,0 +1,22 @@ +//! Private key handling utilities. + +use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; +use std::{fs, path::Path}; +use websock_proto::{Error, Result}; + +/// Load a private key from a file. +pub fn load_key(key_path: &Path) -> Result> { + let key = fs::read(key_path).map_err(|e| Error::Io(e.to_string()))?; + + let key = if key_path.extension().map_or(false, |x| x == "der") { + // Treat raw DER as PKCS#8. + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key)) + } else { + // Decode PEM. + rustls_pemfile::private_key(&mut &*key) + .map_err(|e| Error::Tls(e.to_string()))? + .ok_or_else(|| Error::Io("no keys found".into()))? + }; + + Ok(key) +} diff --git a/websock-tungstenite-mux/src/tls/mod.rs b/websock-tungstenite-mux/src/tls/mod.rs new file mode 100644 index 0000000..6079528 --- /dev/null +++ b/websock-tungstenite-mux/src/tls/mod.rs @@ -0,0 +1,259 @@ +//! TLS configuration helpers for managing certificates and private keys. + +pub mod cert; +pub mod key; + +use websock_proto::{Error, Result}; + +use cert::SkipServerVerification; +use rustls::client::{ClientConfig, WebPkiServerVerifier}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; +use rustls::server::ServerConfig; +use std::path::Path; +use std::sync::Arc; + +/// Alias of [`rustls::server::ServerConfig`]. +pub type TlsServerConfig = rustls::server::ServerConfig; + +/// Alias of [`rustls::client::ClientConfig`]. +pub type TlsClientConfig = rustls::client::ClientConfig; + +/// Generate a self-signed certificate and private key (DER). +pub fn generate_self_signed_pair_der( + subject_alt_names: Vec, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert = rcgen::generate_simple_self_signed(subject_alt_names) + .map_err(|e| Error::Tls(e.to_string()))?; + + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der())); + let cert_chain = vec![CertificateDer::from(cert.cert)]; + Ok((cert_chain, key)) +} + +/// Generate a self-signed certificate and private key (PEM). +pub fn generate_self_signed_pair_pem( + subject_alt_names: Vec, +) -> Result<(Vec, String)> { + let cert = rcgen::generate_simple_self_signed(subject_alt_names) + .map_err(|e| Error::Tls(e.to_string()))?; + + let key = cert.signing_key.serialize_pem(); + let cert_chain = vec![cert.cert.pem()]; + Ok((cert_chain, key)) +} + +/// Load certificate chain and private key from files. +pub fn load_cert( + cert_path: &Path, + key_path: &Path, +) -> Result<(Vec>, PrivateKeyDer<'static>)> { + let cert_chain = cert::load_certs(cert_path)?; + let key = key::load_key(key_path)?; + Ok((cert_chain, key)) +} + +/// Bundled TLS configuration for both client and server use. +#[derive(Debug, Clone)] +pub struct TlsConfig { + /// Client-side rustls configuration. + pub client_config: ClientConfig, + /// Server-side rustls configuration. + pub server_config: ServerConfig, +} + +impl TlsConfig { + /// Create a new TLS configuration with the specified certificate and private key. + pub fn with_cert(cert_path: &Path, key_path: &Path) -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = TlsServerConfigBuilder::new_with_cert(cert_path, key_path)? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with self-signed certificates (localhost). + pub fn with_self_signed_certs() -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with system certificates (server side is self-signed localhost). + pub fn new_native_config() -> Result { + let client_config = TlsClientConfigBuilder::new_with_native_certs()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } + + /// Create a new TLS configuration with no certificate verification (testing only). + pub fn new_insecure_config() -> Result { + let client_config = TlsClientConfigBuilder::new_insecure()? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + let server_config = + TlsServerConfigBuilder::new_with_self_signed_certs(vec!["localhost".into()])? + .with_alpn_protocols(vec![b"h3".to_vec()]) + .build(); + + Ok(Self { + client_config, + server_config, + }) + } +} + +/// Server config builder (owned builder). +#[derive(Debug, Clone)] +pub struct TlsServerConfigBuilder { + inner: TlsServerConfig, +} + +impl TlsServerConfigBuilder { + /// Create an insecure server config with a self-signed certificate. + pub fn new_insecure(subject_alt_names: Vec) -> Result { + let (certs, key) = generate_self_signed_pair_der(subject_alt_names)?; + let inner = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::Tls(e.to_string()))?; + Ok(Self { inner }) + } + + /// Create a server config using certificate and private key files. + pub fn new_with_cert(cert_path: &Path, key_path: &Path) -> Result { + let (certs, key) = load_cert(cert_path, key_path)?; + let inner = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| Error::Tls(e.to_string()))?; + Ok(Self { inner }) + } + + /// Create a server config using a self-signed certificate. + pub fn new_with_self_signed_certs(subject_alt_names: Vec) -> Result { + Self::new_insecure(subject_alt_names) + } + + /// Set ALPN protocol identifiers. + pub fn with_alpn_protocols(mut self, protocols: Vec>) -> Self { + self.inner.alpn_protocols = protocols; + self + } + + /// Finalize the builder and return the server config. + pub fn build(self) -> TlsServerConfig { + self.inner + } +} + +/// Client config builder (owned builder). +#[derive(Debug, Clone)] +pub struct TlsClientConfigBuilder { + inner: TlsClientConfig, +} + +impl TlsClientConfigBuilder { + /// Create an insecure client config that skips certificate verification. + pub fn new_insecure() -> Result { + let inner = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Create a client config that uses the system root store. + pub fn new_with_native_certs() -> Result { + let native_certs = cert::get_native_certs()?; + let inner = ClientConfig::builder() + .with_root_certificates(native_certs) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Create a client config with a custom WebPKI verifier. + pub fn new_with_webpki_verifier(verifier: Arc) -> Result { + let inner = ClientConfig::builder() + .with_webpki_verifier(verifier) + .with_no_client_auth(); + Ok(Self { inner }) + } + + /// Set ALPN protocol identifiers. + pub fn with_alpn_protocols(mut self, protocols: Vec>) -> Self { + self.inner.alpn_protocols = protocols; + self + } + + /// Finalize the builder and return the client config. + pub fn build(self) -> TlsClientConfig { + self.inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_self_signed_pair_der() { + let (cert_chain, key) = generate_self_signed_pair_der(vec!["localhost".into()]).unwrap(); + let rustls_server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key); + + if let Err(e) = rustls_server_config { + panic!("Failed to create ServerConfig: {e}"); + } + } + + #[test] + fn test_generate_self_signed_pair_pem() { + let (cert_chain, key) = generate_self_signed_pair_pem(vec!["localhost".into()]).unwrap(); + + let cert_path = Path::new("cert.pem"); + let key_path = Path::new("key.pem"); + std::fs::write(cert_path, cert_chain.join("\n")).unwrap(); + std::fs::write(key_path, key).unwrap(); + + let (cert_chain, key) = load_cert(cert_path, key_path).unwrap(); + let rustls_server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key); + + if let Err(e) = rustls_server_config { + panic!("Failed to create ServerConfig: {e}"); + } + + std::fs::remove_file(cert_path).unwrap(); + std::fs::remove_file(key_path).unwrap(); + } +} diff --git a/websock-tungstenite/src/builder.rs b/websock-tungstenite/src/builder.rs index b90e339..91663c5 100644 --- a/websock-tungstenite/src/builder.rs +++ b/websock-tungstenite/src/builder.rs @@ -1,19 +1,21 @@ //! Builders for clients and servers using the Tokio Tungstenite transport. +use crate::{Connection, Server}; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::{ClientConfig, RootCertStore, ServerConfig}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::Arc; +use websock_proto::default_ws_alpn; use websock_proto::{ConnectOptions, Error, Result, ServerOptions}; -use crate::{Connection, Server}; - /// Builder for creating a WebSocket client. /// /// The resulting client can be reused for multiple `connect()` calls. #[derive(Debug, Clone)] pub struct ClientBuilder { opts: ConnectOptions, + tls: Option, + alpn: Option>>, } impl Default for ClientBuilder { @@ -27,6 +29,8 @@ impl ClientBuilder { pub fn new() -> Self { Self { opts: ConnectOptions::default(), + tls: None, + alpn: None, } } @@ -78,12 +82,10 @@ impl ClientBuilder { self } - /// Build a reusable client without custom TLS configuration. - pub fn build(self) -> Client { - Client { - opts: self.opts, - tls: None, - } + /// Configure a custom rustls client config (for `wss://`). + pub fn with_tls_config(mut self, tls: ClientConfig) -> Self { + self.tls = Some(tls); + self } /// Build a client configured with the system trust store. @@ -120,6 +122,34 @@ impl ClientBuilder { pub fn dangerous(self) -> DangerousClientBuilder { DangerousClientBuilder { opts: self.opts } } + + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + fn build_tls_config(&self) -> Option> { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(Arc::new(cfg)) + } + + /// Build a client. + pub fn build(&self) -> Client { + Client { + opts: self.opts.clone(), + tls: self.build_tls_config(), + } + } } /// Reusable WebSocket client created by [`ClientBuilder`]. @@ -163,6 +193,7 @@ pub struct ServerBuilder { addr: SocketAddr, opts: ServerOptions, tls: Option, + alpn: Option>>, } impl Default for ServerBuilder { @@ -177,6 +208,7 @@ impl ServerBuilder { Self { addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)), opts: ServerOptions::default(), + alpn: None, tls: None, } } @@ -230,8 +262,28 @@ impl ServerBuilder { self } + pub fn with_default_alpn(mut self) -> Self { + self.alpn = Some(default_ws_alpn()); + self + } + + pub fn with_alpn_protocols(mut self, alpn: Vec>) -> Self { + self.alpn = Some(alpn); + self + } + + fn build_tls_config(&self) -> Option { + let mut cfg = self.tls.clone()?; + + if let Some(alpn) = &self.alpn { + cfg.alpn_protocols = alpn.clone(); + } + + Some(cfg) + } + /// Bind the listener and return a server instance. - pub async fn build(self) -> Result { - crate::server::bind(&self.addr, self.opts, self.tls).await + pub async fn build(&self) -> Result { + crate::server::bind(self.addr, self.opts.clone(), self.build_tls_config()).await } } diff --git a/websock-wasm-demo/Cargo.toml b/websock-wasm-demo/Cargo.toml index b10e7dd..b8cfc26 100644 --- a/websock-wasm-demo/Cargo.toml +++ b/websock-wasm-demo/Cargo.toml @@ -10,8 +10,9 @@ crate-type = ["cdylib"] [dependencies] websock = { path = "../websock" } +websock-mux = { path = "../websock-mux" } wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" console_error_panic_hook = "0.1" -web-sys = { version = "0.3", features = ["console", "Document", "HtmlButtonElement", "Window"] } +web-sys = { version = "0.3", features = ["console", "Document", "HtmlInputElement", "HtmlButtonElement", "Window"] } futures-util = { version = "0.3", features = ["sink"] } diff --git a/websock-wasm-demo/index.html b/websock-wasm-demo/index.html index 4737082..47ebf50 100644 --- a/websock-wasm-demo/index.html +++ b/websock-wasm-demo/index.html @@ -1,12 +1,38 @@ - + - websock wasm demo + + websock-wasm-demo + - - -

Open DevTools Console to see logs.

+

websock-wasm-demo

+ +
+ +
+ +
+ + + +
+ + + Tip: For TLS testing use wss:// and a certificate that your browser trusts. + + +

log

+

+
+    
   
-
\ No newline at end of file
+
diff --git a/websock-wasm-demo/src/echo.rs b/websock-wasm-demo/src/echo.rs
new file mode 100644
index 0000000..6739b44
--- /dev/null
+++ b/websock-wasm-demo/src/echo.rs
@@ -0,0 +1,60 @@
+use futures_util::{SinkExt, StreamExt};
+use websock::{ClientBuilder, Message};
+
+//pub const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org";
+
+pub async fn run_conn_demo(url: &str, log: impl Fn(&str)) {
+    log("[conn demo] start");
+
+    let client = ClientBuilder::new()
+        .with_system_roots()
+        .expect("client config failed");
+
+    log(&format!("[conn demo] Connecting to server at {url}..."));
+    log(&format!(
+        "[conn demo] Using WebSocket options: {:?}",
+        client.options()
+    ));
+
+    let mut conn = client.connect(url).await.expect("connect failed");
+
+    conn.send(Message::Text("hello from wasm".into()))
+        .await
+        .expect("send failed");
+
+    let msg = conn.recv().await.expect("recv failed");
+    log(&format!("[conn demo] got: {msg:?}"));
+
+    conn.close().await.expect("close failed");
+    log("[conn demo] done");
+}
+
+pub async fn run_split_demo(url: &str, log: impl Fn(&str)) {
+    log("[split demo] start");
+
+    let client = ClientBuilder::new()
+        .with_system_roots()
+        .expect("client config failed");
+
+    log(&format!("[split demo] Connecting to server at {url}..."));
+    log(&format!(
+        "[split demo] Using WebSocket options: {:?}",
+        client.options()
+    ));
+
+    let conn = client.connect(url).await.expect("connect failed");
+    let (mut sink, mut stream) = websock::stream::split(conn);
+
+    sink.send(Message::Text("hello from wasm split".into()))
+        .await
+        .expect("send failed");
+
+    let msg = stream
+        .next()
+        .await
+        .expect("stream closed")
+        .expect("recv failed");
+
+    log(&format!("[split demo] got: {msg:?}"));
+    log("[split demo] done");
+}
diff --git a/websock-wasm-demo/src/echo_mux.rs b/websock-wasm-demo/src/echo_mux.rs
new file mode 100644
index 0000000..6fd5dc6
--- /dev/null
+++ b/websock-wasm-demo/src/echo_mux.rs
@@ -0,0 +1,39 @@
+pub const DEFAULT_MUX_URL: &str = "wss://localhost:9001";
+
+pub async fn run_mux_bi_demo(url: &str, log: impl Fn(&str)) {
+    log("[mux bi demo] start");
+
+    let client = websock_mux::ClientBuilder::new()
+        .with_system_roots()
+        .expect("client config failed");
+
+    log(&format!(
+        "[mux bi demo] Connecting to mux server at {url}..."
+    ));
+    log(&format!(
+        "[mux bi demo] Using WebSocket options: {:?}",
+        client.options()
+    ));
+
+    // Mux Session
+    let session = match client.connect(url).await {
+        Ok(s) => s,
+        Err(e) => {
+            log(&format!("[mux bi demo] connect failed: {e:?}"));
+            return;
+        }
+    };
+
+    let (send, mut recv) = session.open_bi().expect("open_bi failed");
+
+    send.write(b"hello mux from wasm").expect("send failed");
+    send.finish().expect("finish failed");
+
+    let mut buf = vec![0u8; 1024];
+    while let Some(n) = recv.read(&mut buf).await.expect("read failed") {
+        let text = String::from_utf8_lossy(&buf[..n]);
+        log(&format!("[mux bi demo] recv: {}", text));
+    }
+
+    log("[mux bi demo] done");
+}
diff --git a/websock-wasm-demo/src/lib.rs b/websock-wasm-demo/src/lib.rs
index 595ddb4..ed59df0 100644
--- a/websock-wasm-demo/src/lib.rs
+++ b/websock-wasm-demo/src/lib.rs
@@ -1,103 +1,100 @@
-//! Minimal WebAssembly demo that exercises the websock API in the browser.
+//! Minimal WebAssembly demo for WebSocket + WebSocket-mux in the browser.
 
-use wasm_bindgen::prelude::*;
-use wasm_bindgen_futures::spawn_local;
-use web_sys::{Document, HtmlButtonElement};
+#![cfg(all(target_arch = "wasm32", not(target_os = "wasi")))]
 
-use futures_util::{SinkExt, StreamExt};
-use websock::{ClientBuilder, Message};
+mod echo;
+mod echo_mux;
 
-// Alternate echo server used for quick testing.
-const DEFAULT_ECHO_URL: &str = "wss://echo.websocket.org";
+use wasm_bindgen::JsCast;
+use wasm_bindgen::prelude::*;
+use wasm_bindgen_futures::spawn_local;
+use web_sys::{Document, HtmlButtonElement, HtmlInputElement};
 
-/// Obtain the active document.
 fn document() -> Document {
     web_sys::window().unwrap().document().unwrap()
 }
 
-/// Log a message to the browser console.
-fn log(msg: &str) {
-    web_sys::console::log_1(&msg.into());
+fn by_id(id: &str) -> T {
+    document()
+        .get_element_by_id(id)
+        .unwrap_or_else(|| panic!("missing element: #{id}"))
+        .dyn_into()
+        .unwrap()
 }
 
-/// Bind a click handler to a button by element ID.
-fn hook_button(id: &str, f: impl Fn() + 'static) {
+fn append_log(msg: &str) {
     let doc = document();
-    let el = doc
-        .get_element_by_id(id)
-        .unwrap_or_else(|| panic!("missing element: #{id}"));
-    let btn: HtmlButtonElement = el.dyn_into().unwrap();
-
-    let cb = Closure::::new(move || {
-        f();
-    });
+    let el = doc.get_element_by_id("log").unwrap();
+    let mut text = el.text_content().unwrap_or_default();
+    text.push_str(msg);
+    text.push('\n');
+    el.set_text_content(Some(&text));
 
-    btn.add_event_listener_with_callback("click", cb.as_ref().unchecked_ref())
-        .unwrap();
+    // Also mirror into DevTools for convenience.
+    web_sys::console::log_1(&msg.into());
+}
 
-    // Keep the handler alive for the lifetime of the page.
-    cb.forget();
+fn get_url() -> String {
+    let input: HtmlInputElement = by_id("url");
+    let url = input.value();
+    if url.trim().is_empty() {
+        echo_mux::DEFAULT_MUX_URL.to_string()
+    } else {
+        url
+    }
 }
 
 #[wasm_bindgen(start)]
-/// Entry point invoked by the browser when the module loads.
 pub fn start() {
     console_error_panic_hook::set_once();
 
-    hook_button("btn-conn", || {
-        spawn_local(async move {
-            log("[conn demo] start");
-
-            let client = ClientBuilder::new()
-                .with_system_roots()
-                .expect("client config failed");
-
-            let mut conn = client
-                .connect(DEFAULT_ECHO_URL)
-                .await
-                .expect("connect failed");
-
-            conn.send(Message::Text("hello from wasm".into()))
-                .await
-                .expect("send failed");
-
-            let msg = conn.recv().await.expect("recv failed");
-            log(&format!("[conn demo] got: {msg:?}"));
-
-            conn.close().await.expect("close failed");
-            log("[conn demo] done");
+    // Default URL.
+    {
+        let input: HtmlInputElement = by_id("url");
+        if input.value().trim().is_empty() {
+            input.set_value(echo_mux::DEFAULT_MUX_URL);
+        }
+    }
+
+    let btn_conn: HtmlButtonElement = by_id("btn-conn");
+    let btn_split: HtmlButtonElement = by_id("btn-split");
+    let btn_mux_bi: HtmlButtonElement = by_id("btn-mux-bi");
+
+    // conn demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo::run_conn_demo(&url, |m| append_log(m)).await;
+            });
         });
-    });
-
-    hook_button("btn-split", || {
-        spawn_local(async move {
-            log("[split demo] start");
-
-            let client = ClientBuilder::new()
-                .with_system_roots()
-                .expect("client config failed");
-
-            let conn = client
-                .connect(DEFAULT_ECHO_URL)
-                .await
-                .expect("connect failed");
-
-            let (mut sink, mut stream) = websock::stream::split(conn);
-
-            sink.send(Message::Text("hello from wasm split".into()))
-                .await
-                .expect("send failed");
-
-            let msg = stream
-                .next()
-                .await
-                .expect("stream closed")
-                .expect("recv failed");
-
-            log(&format!("[split demo] got: {msg:?}"));
-            log("[split demo] done");
+        btn_conn.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
+
+    // split demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo::run_split_demo(&url, |m| append_log(m)).await;
+            });
+        });
+        btn_split.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
+
+    // mux bi demo
+    {
+        let cb = Closure::::new(move || {
+            let url = get_url();
+            spawn_local(async move {
+                echo_mux::run_mux_bi_demo(&url, |m| append_log(m)).await;
+            });
         });
-    });
+        btn_mux_bi.set_onclick(Some(cb.as_ref().unchecked_ref()));
+        cb.forget();
+    }
 
-    log("ready: click [conn demo] or [split demo]");
+    append_log("ready");
 }
diff --git a/websock-wasm-mux/Cargo.toml b/websock-wasm-mux/Cargo.toml
new file mode 100644
index 0000000..b8e3ad6
--- /dev/null
+++ b/websock-wasm-mux/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "websock-wasm-mux"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+description = "WebAssembly WebSocket multiplexing transport (WebTransport-like streams over WebSocket)."
+repository = "https://github.com/foctal/websock"
+readme = "../README.md"
+keywords = ["network", "websocket", "multiplex", "wasm"]
+categories = ["network-programming", "web-programming"]
+license = "MIT"
+
+[dependencies]
+websock-proto = { workspace = true }
+websock-mux-proto = { workspace = true }
+websock-wasm = { workspace = true }
+bytes = { workspace = true }
+futures-channel = "0.3"
+futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
+wasm-bindgen-futures = "0.4"
diff --git a/websock-wasm-mux/src/builder.rs b/websock-wasm-mux/src/builder.rs
new file mode 100644
index 0000000..4cb330e
--- /dev/null
+++ b/websock-wasm-mux/src/builder.rs
@@ -0,0 +1,86 @@
+//! Builder for browser mux clients.
+
+use crate::{Client, Limits};
+use websock_mux_proto::SUBPROTOCOL;
+use websock_proto::{ConnectOptions, Result};
+
+/// Builder for creating a browser WebSocket mux client.
+///
+/// This wraps `websock-wasm` and enforces the mux `SUBPROTOCOL`.
+#[derive(Debug, Clone)]
+pub struct ClientBuilder {
+    opts: ConnectOptions,
+    limits: Limits,
+}
+
+impl Default for ClientBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ClientBuilder {
+    /// Create a new builder with default options.
+    pub fn new() -> Self {
+        let mut opts = ConnectOptions::default();
+        opts.protocols.push(SUBPROTOCOL.to_string());
+        Self {
+            opts,
+            limits: Limits::default(),
+        }
+    }
+
+    /// Replace the builder options wholesale.
+    ///
+    /// Note: this will re-append the mux subprotocol if missing.
+    pub fn with_options(mut self, mut opts: ConnectOptions) -> Self {
+        if !opts.protocols.iter().any(|p| p == SUBPROTOCOL) {
+            opts.protocols.push(SUBPROTOCOL.to_string());
+        }
+        self.opts = opts;
+        self
+    }
+
+    /// Return a reference to the current options.
+    pub fn options(&self) -> &ConnectOptions {
+        &self.opts
+    }
+
+    /// Add a single header to the connection request.
+    pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self {
+        self.opts.headers.push((name.into(), value.into()));
+        self
+    }
+
+    /// Add multiple headers to the connection request.
+    pub fn with_headers(mut self, headers: I) -> Self
+    where
+        I: IntoIterator,
+        K: Into,
+        V: Into,
+    {
+        for (k, v) in headers {
+            self.opts.headers.push((k.into(), v.into()));
+        }
+        self
+    }
+
+    /// Configure session limits.
+    pub fn limits(mut self, limits: Limits) -> Self {
+        self.limits = limits;
+        self
+    }
+
+    /// Build a reusable client.
+    pub fn build(self) -> Client {
+        Client {
+            opts: self.opts,
+            limits: self.limits,
+        }
+    }
+
+    /// Build a client (no-op for browser cert roots).
+    pub fn with_system_roots(self) -> Result {
+        Ok(self.build())
+    }
+}
diff --git a/websock-wasm-mux/src/client.rs b/websock-wasm-mux/src/client.rs
new file mode 100644
index 0000000..02015d8
--- /dev/null
+++ b/websock-wasm-mux/src/client.rs
@@ -0,0 +1,24 @@
+//! Browser mux client.
+
+use crate::{Limits, Session};
+use websock_proto::{ConnectOptions, Result};
+
+/// Reusable browser WebSocket mux client created by [`crate::ClientBuilder`].
+#[derive(Debug, Clone)]
+pub struct Client {
+    pub(crate) opts: ConnectOptions,
+    pub(crate) limits: Limits,
+}
+
+impl Client {
+    /// Return a reference to the configured connection options.
+    pub fn options(&self) -> &ConnectOptions {
+        &self.opts
+    }
+
+    /// Establish a browser WebSocket connection and create a mux [`Session`].
+    pub async fn connect(&self, url: &str) -> Result {
+        let conn = websock_wasm::connect(url, self.opts.clone()).await?;
+        Ok(Session::new(conn, self.limits.clone()))
+    }
+}
diff --git a/websock-wasm-mux/src/lib.rs b/websock-wasm-mux/src/lib.rs
new file mode 100644
index 0000000..e8e15f3
--- /dev/null
+++ b/websock-wasm-mux/src/lib.rs
@@ -0,0 +1,12 @@
+//! Browser (wasm32) WebSocket multiplexing transport.
+//!
+//! This crate provides a QUIC/WebTransport-like *logical stream* interface over a single
+//! browser WebSocket connection.
+
+mod builder;
+mod client;
+mod session;
+
+pub use builder::ClientBuilder;
+pub use client::Client;
+pub use session::{Limits, RecvStream, SendStream, Session};
diff --git a/websock-wasm-mux/src/session.rs b/websock-wasm-mux/src/session.rs
new file mode 100644
index 0000000..a14b469
--- /dev/null
+++ b/websock-wasm-mux/src/session.rs
@@ -0,0 +1,560 @@
+use std::cell::RefCell;
+use std::collections::HashMap;
+use std::rc::Rc;
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+
+use bytes::{BufMut, Bytes};
+use futures_channel::mpsc;
+use futures_util::lock::Mutex;
+use futures_util::{FutureExt, StreamExt};
+use wasm_bindgen_futures::spawn_local;
+use websock_proto::{Error, Message, Result};
+
+use websock_mux_proto::{Frame, StreamDir, StreamId};
+
+/// Session limits to prevent unbounded buffering / DoS.
+#[derive(Debug, Clone)]
+pub struct Limits {
+    /// Maximum size of a single WebSocket binary message accepted by the inbound loop.
+    pub max_ws_message_size: usize,
+    /// Maximum `Stream` frame payload size.
+    pub max_stream_data_per_frame: usize,
+    /// Maximum number of concurrently open receive streams.
+    pub max_open_streams: usize,
+    /// Per-stream receive event queue length.
+    pub recv_event_queue_len: usize,
+    /// Session outbound queue length.
+    pub outbound_queue_len: usize,
+    /// Queue length for accepting inbound uni streams.
+    pub accept_uni_queue_len: usize,
+    /// Queue length for accepting inbound bi streams.
+    pub accept_bi_queue_len: usize,
+}
+
+impl Default for Limits {
+    fn default() -> Self {
+        Self {
+            max_ws_message_size: 1 * 1024 * 1024,
+            max_stream_data_per_frame: 256 * 1024,
+            max_open_streams: 1024,
+            recv_event_queue_len: 128,
+            outbound_queue_len: 256,
+            accept_uni_queue_len: 128,
+            accept_bi_queue_len: 128,
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct Session {
+    inner: Rc,
+    accept_uni: Rc>>,
+    accept_bi: Rc>>,
+}
+
+impl Session {
+    pub(crate) fn new(conn: websock_wasm::Connection, limits: Limits) -> Self {
+        let (outbound_tx, outbound_rx) = mpsc::channel::(limits.outbound_queue_len);
+        let (accept_uni_tx, accept_uni_rx) =
+            mpsc::channel::(limits.accept_uni_queue_len);
+        let (accept_bi_tx, accept_bi_rx) =
+            mpsc::channel::<(SendStream, RecvStream)>(limits.accept_bi_queue_len);
+
+        let inner = Rc::new(SessionInner::new(
+            limits,
+            outbound_tx,
+            accept_uni_tx,
+            accept_bi_tx,
+        ));
+
+        let session = Self {
+            inner: inner.clone(),
+            accept_uni: Rc::new(Mutex::new(accept_uni_rx)),
+            accept_bi: Rc::new(Mutex::new(accept_bi_rx)),
+        };
+
+        inner.spawn_task(conn, outbound_rx);
+        session
+    }
+
+    pub fn open_uni(&self) -> Result {
+        let id = self.inner.next_stream_id(StreamDir::Uni)?;
+        self.inner.send_frame(Frame::OpenUni { id })?;
+        Ok(SendStream::new(id, self.inner.clone()))
+    }
+
+    pub fn open_bi(&self) -> Result<(SendStream, RecvStream)> {
+        let id = self.inner.next_stream_id(StreamDir::Bi)?;
+        let recv = self.inner.clone().register_recv_stream(id);
+        self.inner.send_frame(Frame::OpenBi { id })?;
+        Ok((SendStream::new(id, self.inner.clone()), recv))
+    }
+
+    pub async fn accept_uni(&self) -> Result {
+        let mut rx = self.accept_uni.lock().await;
+        rx.next().await.ok_or(Error::Closed)
+    }
+
+    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream)> {
+        let mut rx = self.accept_bi.lock().await;
+        rx.next().await.ok_or(Error::Closed)
+    }
+}
+
+#[derive(Clone)]
+pub struct SendStream {
+    id: StreamId,
+    session: Rc,
+    finished: Rc,
+}
+
+impl SendStream {
+    fn new(id: StreamId, session: Rc) -> Self {
+        Self {
+            id,
+            session,
+            finished: Rc::new(AtomicBool::new(false)),
+        }
+    }
+
+    pub fn write(&self, data: &[u8]) -> Result<()> {
+        self.write_buf(Bytes::copy_from_slice(data))
+    }
+
+    pub fn write_buf(&self, data: Bytes) -> Result<()> {
+        self.session.send_frame(Frame::Stream {
+            id: self.id,
+            data,
+            fin: false,
+        })
+    }
+
+    pub fn write_all(&self, data: &[u8]) -> Result<()> {
+        self.write(data)
+    }
+
+    pub fn finish(&self) -> Result<()> {
+        if self
+            .finished
+            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
+            .is_ok()
+        {
+            self.session.send_frame(Frame::Stream {
+                id: self.id,
+                data: Bytes::new(),
+                fin: true,
+            })?;
+        }
+        Ok(())
+    }
+
+    pub fn reset(&self, code: u64) -> Result<()> {
+        self.finished.store(true, Ordering::SeqCst);
+        self.session
+            .send_frame(Frame::ResetStream { id: self.id, code })
+    }
+
+    pub fn closed(&self) -> bool {
+        self.finished.load(Ordering::SeqCst)
+    }
+}
+
+impl Drop for SendStream {
+    fn drop(&mut self) {
+        if !self.finished.load(Ordering::SeqCst) {
+            let _ = self.session.try_send_frame(Frame::ResetStream {
+                id: self.id,
+                code: 0,
+            });
+        }
+    }
+}
+
+#[derive(Debug)]
+struct RecvEvent {
+    data: Bytes,
+    fin: bool,
+}
+
+pub struct RecvStream {
+    id: StreamId,
+    session: Rc,
+    receiver: mpsc::Receiver,
+    finished: bool,
+    pending: Bytes,
+}
+
+impl RecvStream {
+    fn new(id: StreamId, session: Rc, receiver: mpsc::Receiver) -> Self {
+        Self {
+            id,
+            session,
+            receiver,
+            finished: false,
+            pending: Bytes::new(),
+        }
+    }
+
+    pub async fn read(&mut self, buf: &mut [u8]) -> Result> {
+        if self.finished {
+            return Ok(None);
+        }
+        if self.pending.is_empty() {
+            if let Some(chunk) = self.read_chunk_internal().await? {
+                self.pending = chunk;
+            } else {
+                return Ok(None);
+            }
+        }
+        let amt = buf.len().min(self.pending.len());
+        buf[..amt].copy_from_slice(&self.pending[..amt]);
+        self.pending = self.pending.slice(amt..);
+        Ok(Some(amt))
+    }
+
+    pub async fn read_buf(&mut self, buf: &mut B) -> Result> {
+        let mut temp = vec![0u8; 4096];
+        let size_opt = self.read(&mut temp).await?;
+        if let Some(size) = size_opt {
+            buf.put_slice(&temp[..size]);
+            Ok(Some(size))
+        } else {
+            Ok(None)
+        }
+    }
+
+    async fn read_chunk_internal(&mut self) -> Result> {
+        match self.receiver.next().await {
+            Some(event) => {
+                if event.fin {
+                    self.finished = true;
+                }
+                if event.data.is_empty() && event.fin {
+                    Ok(None)
+                } else {
+                    Ok(Some(event.data))
+                }
+            }
+            None => {
+                self.finished = true;
+                Ok(None)
+            }
+        }
+    }
+
+    pub async fn read_chunk(&mut self, max: usize) -> Result> {
+        match self.receiver.next().await {
+            Some(mut event) => {
+                if event.fin {
+                    self.finished = true;
+                }
+                if event.data.is_empty() && event.fin {
+                    Ok(None)
+                } else if event.data.len() > max {
+                    let chunk = event.data.split_to(max);
+                    self.pending = event.data;
+                    Ok(Some(chunk))
+                } else {
+                    Ok(Some(event.data))
+                }
+            }
+            None => {
+                self.finished = true;
+                Ok(None)
+            }
+        }
+    }
+
+    pub fn stop(&self, code: u64) -> Result<()> {
+        self.session
+            .send_frame(Frame::StopSending { id: self.id, code })
+    }
+
+    pub fn closed(&self) -> bool {
+        self.finished
+    }
+}
+
+impl Drop for RecvStream {
+    fn drop(&mut self) {
+        if !self.finished {
+            let _ = self.session.try_send_frame(Frame::StopSending {
+                id: self.id,
+                code: 0,
+            });
+        }
+    }
+}
+
+struct SessionInner {
+    limits: Limits,
+    outbound_tx: RefCell>,
+    accept_uni_tx: Mutex>>,
+    accept_bi_tx: Mutex>>,
+    streams: RefCell>>,
+    next_uni: AtomicU64,
+    next_bi: AtomicU64,
+    closed: AtomicBool,
+}
+
+impl SessionInner {
+    fn new(
+        limits: Limits,
+        outbound_tx: mpsc::Sender,
+        accept_uni_tx: mpsc::Sender,
+        accept_bi_tx: mpsc::Sender<(SendStream, RecvStream)>,
+    ) -> Self {
+        Self {
+            limits,
+            outbound_tx: RefCell::new(outbound_tx),
+            accept_uni_tx: Mutex::new(Some(accept_uni_tx)),
+            accept_bi_tx: Mutex::new(Some(accept_bi_tx)),
+            streams: RefCell::new(HashMap::new()),
+            next_uni: AtomicU64::new(0),
+            next_bi: AtomicU64::new(0),
+            closed: AtomicBool::new(false),
+        }
+    }
+
+    fn spawn_task(
+        self: Rc,
+        mut conn: websock_wasm::Connection,
+        mut outbound_rx: mpsc::Receiver,
+    ) {
+        let inner = self.clone();
+        spawn_local(async move {
+            loop {
+                futures_util::select! {
+                    msg = conn.recv().fuse() => {
+                        match msg {
+                            Ok(Message::Binary(data)) => {
+                                if data.len() > inner.limits.max_ws_message_size {
+                                    let _ = inner.protocol_error(2, "ws message too large").await;
+                                    break;
+                                }
+                                let mut cursor = &data[..];
+                                let frame = match Frame::decode(&mut cursor) {
+                                    Ok(f) => f,
+                                    Err(_) => {
+                                        let _ = inner.protocol_error(1, "invalid frame").await;
+                                        break;
+                                    }
+                                };
+                                if inner.clone().handle_frame(frame).await.is_err() {
+                                    break;
+                                }
+                            }
+                            Ok(Message::Text(_)) => {
+                                let _ = inner.protocol_error(1, "text message not supported").await;
+                                break;
+                            }
+                            Err(_) => break,
+                        }
+                    }
+                    out = outbound_rx.next().fuse() => {
+                        match out {
+                            Some(m) => {
+                                if conn.send(m).await.is_err() {
+                                    break;
+                                }
+                            }
+                            None => break,
+                        }
+                    }
+                }
+            }
+
+            inner.close_all().await;
+            let _ = conn.close().await;
+        });
+    }
+
+    fn next_stream_id(&self, dir: StreamDir) -> Result {
+        let is_server = false; // browser is always client
+        let n = match dir {
+            StreamDir::Uni => self.next_uni.fetch_add(1, Ordering::SeqCst),
+            StreamDir::Bi => self.next_bi.fetch_add(1, Ordering::SeqCst),
+        };
+        StreamId::new(n, is_server, dir)
+            .map_err(|e| Error::Protocol(format!("stream id overflow: {}", e)))
+    }
+
+    async fn handle_frame(self: Rc, frame: Frame) -> Result<()> {
+        match frame {
+            Frame::OpenUni { id } => {
+                if id.dir() != StreamDir::Uni {
+                    return self
+                        .protocol_error(1, "OpenUni with non-uni StreamId")
+                        .await;
+                }
+                // In browser wasm, we are always the client, so the peer is the server.
+                // Therefore, inbound streams must be server-initiated.
+                if !id.initiator_is_server() {
+                    return self.protocol_error(1, "OpenUni with wrong initiator").await;
+                }
+
+                let mut map = self.streams.borrow_mut();
+                if map.len() >= self.limits.max_open_streams {
+                    drop(map);
+                    return self.protocol_error(3, "too many open streams").await;
+                }
+                if map.contains_key(&id) {
+                    drop(map);
+                    return self.protocol_error(1, "duplicate stream open").await;
+                }
+
+                let recv = Self::register_recv_stream_locked(&self, &mut map, id);
+                drop(map);
+
+                let tx = self.accept_uni_tx.lock().await.clone();
+                if let Some(mut tx) = tx {
+                    match tx.try_send(recv) {
+                        Ok(()) => Ok(()),
+                        Err(e) => {
+                            if e.is_full() {
+                                self.streams.borrow_mut().remove(&id);
+                                let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 });
+                                Ok(())
+                            } else {
+                                Err(Error::Closed)
+                            }
+                        }
+                    }
+                } else {
+                    Err(Error::Closed)
+                }
+            }
+            Frame::OpenBi { id } => {
+                if id.dir() != StreamDir::Bi {
+                    return self.protocol_error(1, "OpenBi with non-bi StreamId").await;
+                }
+                if !id.initiator_is_server() {
+                    return self.protocol_error(1, "OpenBi with wrong initiator").await;
+                }
+
+                let mut map = self.streams.borrow_mut();
+                if map.len() >= self.limits.max_open_streams {
+                    drop(map);
+                    return self.protocol_error(3, "too many open streams").await;
+                }
+                if map.contains_key(&id) {
+                    drop(map);
+                    return self.protocol_error(1, "duplicate stream open").await;
+                }
+
+                let recv = Self::register_recv_stream_locked(&self, &mut map, id);
+                drop(map);
+
+                let send = SendStream::new(id, self.clone());
+                let tx = self.accept_bi_tx.lock().await.clone();
+                if let Some(mut tx) = tx {
+                    match tx.try_send((send, recv)) {
+                        Ok(()) => Ok(()),
+                        Err(e) => {
+                            if e.is_full() {
+                                self.streams.borrow_mut().remove(&id);
+                                let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 });
+                                Ok(())
+                            } else {
+                                Err(Error::Closed)
+                            }
+                        }
+                    }
+                } else {
+                    Err(Error::Closed)
+                }
+            }
+            Frame::Stream { id, data, fin } => {
+                if data.len() > self.limits.max_stream_data_per_frame {
+                    return self.protocol_error(2, "stream data too large").await;
+                }
+
+                let mut map = self.streams.borrow_mut();
+                let Some(tx) = map.get_mut(&id) else {
+                    drop(map);
+                    return self
+                        .protocol_error(1, "Stream data on unknown stream")
+                        .await;
+                };
+
+                match tx.try_send(RecvEvent { data, fin }) {
+                    Ok(()) => {}
+                    Err(e) => {
+                        if e.is_full() {
+                            map.remove(&id);
+                            drop(map);
+                            let _ = self.try_send_frame(Frame::ResetStream { id, code: 3 });
+                            return Ok(());
+                        } else {
+                            map.remove(&id);
+                            return Ok(());
+                        }
+                    }
+                }
+
+                if fin {
+                    map.remove(&id);
+                }
+                Ok(())
+            }
+            Frame::ResetStream { id, .. } | Frame::StopSending { id, .. } => {
+                self.streams.borrow_mut().remove(&id);
+                Ok(())
+            }
+            Frame::ConnectionClose { .. } => {
+                self.close_all().await;
+                Err(Error::Closed)
+            }
+        }
+    }
+
+    fn register_recv_stream(self: Rc, id: StreamId) -> RecvStream {
+        let mut map = self.streams.borrow_mut();
+        Self::register_recv_stream_locked(&self, &mut map, id)
+    }
+
+    fn register_recv_stream_locked(
+        this: &Rc,
+        map: &mut HashMap>,
+        id: StreamId,
+    ) -> RecvStream {
+        let (tx, rx) = mpsc::channel(this.limits.recv_event_queue_len);
+        map.insert(id, tx);
+        RecvStream::new(id, this.clone(), rx)
+    }
+
+    fn try_send_frame(&self, frame: Frame) -> std::result::Result<(), Error> {
+        let bytes = frame.encode().freeze();
+        self.outbound_tx
+            .borrow_mut()
+            .try_send(Message::Binary(bytes))
+            .map_err(|_| Error::Closed)
+    }
+
+    fn send_frame(&self, frame: Frame) -> Result<()> {
+        self.try_send_frame(frame)
+    }
+
+    async fn protocol_error(&self, code: u64, reason: &str) -> Result<()> {
+        let _ = self.try_send_frame(Frame::ConnectionClose {
+            code,
+            reason: reason.to_string(),
+        });
+        self.close_all().await;
+        Err(Error::Protocol(reason.to_string()))
+    }
+
+    async fn close_all(&self) {
+        if self
+            .closed
+            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
+            .is_err()
+        {
+            return;
+        }
+
+        self.streams.borrow_mut().clear();
+        *self.accept_uni_tx.lock().await = None;
+        *self.accept_bi_tx.lock().await = None;
+    }
+}
diff --git a/websock-wasm/src/connection.rs b/websock-wasm/src/connection.rs
index 87f7e60..7af5f6d 100644
--- a/websock-wasm/src/connection.rs
+++ b/websock-wasm/src/connection.rs
@@ -58,21 +58,24 @@ pub async fn connect(url: &str, opts: ConnectOptions) -> Result {
     ws.set_onclose(Some(wait_onclose.as_ref().unchecked_ref()));
 
     // Wait until the connection is opened or fails.
-    match open_rx.await {
-        Ok(Ok(())) => {}
-        Ok(Err(e)) => return Err(e),
-        Err(_) => return Err(Error::Other("onopen waiter dropped".into())),
-    }
+    let open_res = open_rx.await;
 
-    // Unset the connection process handlers.
+    // Always unset the connection process handlers.
     ws.set_onopen(None);
     ws.set_onerror(None);
     ws.set_onclose(None);
 
+    // Drop closures AFTER unsetting.
     drop(wait_onopen);
     drop(wait_onerror);
     drop(wait_onclose);
 
+    match open_res {
+        Ok(Ok(())) => {}
+        Ok(Err(e)) => return Err(e),
+        Err(_) => return Err(Error::Other("onopen waiter dropped".into())),
+    }
+
     // Set up message/error/close handlers.
     let tx_msg = tx.clone();
     let onmessage =
diff --git a/websock/examples/echo-server.rs b/websock/examples/echo-server.rs
index 1a039e6..a009bff 100644
--- a/websock/examples/echo-server.rs
+++ b/websock/examples/echo-server.rs
@@ -36,7 +36,9 @@ async fn main() -> anyhow::Result<()> {
 
     let args = Args::parse();
 
-    let mut builder = ServerBuilder::new().with_addr(args.addr);
+    let mut builder = ServerBuilder::new()
+        .with_addr(args.addr)
+        .with_default_alpn();
 
     match (&args.tls_cert, &args.tls_key) {
         (Some(cert_path), Some(key_path)) => {